mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-08 02:37:42 +08:00
refactor: async CacheProvider API + reduce public surface
- Make on_lookup/on_store async on CacheProvider ABC - Simplify CacheContext: replace cache_key + cache_key_bytes with cache_key_hash (str hex digest) - Make registry/utility functions internal (_prefix) - Trim comfy_api.latest.Caching exports to core API only - Make cache get/set async throughout caching.py hierarchy - Use asyncio.create_task for fire-and-forget on_store - Add NaN gating before provider calls in Core - Add await to 5 cache call sites in execution.py Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
0141af0786
commit
4cbe4fe4c7
@ -118,12 +118,12 @@ class Caching:
|
|||||||
from comfy_api.latest import Caching
|
from comfy_api.latest import Caching
|
||||||
|
|
||||||
class MyRedisProvider(Caching.CacheProvider):
|
class MyRedisProvider(Caching.CacheProvider):
|
||||||
def on_lookup(self, context):
|
async def on_lookup(self, context):
|
||||||
# Check Redis for cached result
|
# Check Redis for cached result
|
||||||
...
|
...
|
||||||
|
|
||||||
def on_store(self, context, value):
|
async def on_store(self, context, value):
|
||||||
# Store to Redis (can be async internally)
|
# Store to Redis
|
||||||
...
|
...
|
||||||
|
|
||||||
Caching.register_provider(MyRedisProvider())
|
Caching.register_provider(MyRedisProvider())
|
||||||
@ -135,10 +135,6 @@ class Caching:
|
|||||||
CacheValue,
|
CacheValue,
|
||||||
register_cache_provider as register_provider,
|
register_cache_provider as register_provider,
|
||||||
unregister_cache_provider as unregister_provider,
|
unregister_cache_provider as unregister_provider,
|
||||||
get_cache_providers as get_providers,
|
|
||||||
has_cache_providers as has_providers,
|
|
||||||
clear_cache_providers as clear_providers,
|
|
||||||
estimate_value_size,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -13,12 +13,12 @@ Example usage:
|
|||||||
)
|
)
|
||||||
|
|
||||||
class MyRedisProvider(CacheProvider):
|
class MyRedisProvider(CacheProvider):
|
||||||
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
# Check Redis/GCS for cached result
|
# Check Redis/GCS for cached result
|
||||||
...
|
...
|
||||||
|
|
||||||
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
# Store to Redis/GCS (can be async internally)
|
# Store to Redis/GCS
|
||||||
...
|
...
|
||||||
|
|
||||||
register_cache_provider(MyRedisProvider())
|
register_cache_provider(MyRedisProvider())
|
||||||
@ -34,7 +34,7 @@ import math
|
|||||||
import pickle
|
import pickle
|
||||||
import threading
|
import threading
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
@ -47,8 +47,7 @@ class CacheContext:
|
|||||||
prompt_id: str # Current prompt execution ID
|
prompt_id: str # Current prompt execution ID
|
||||||
node_id: str # Node being cached
|
node_id: str # Node being cached
|
||||||
class_type: str # Node class type (e.g., "KSampler")
|
class_type: str # Node class type (e.g., "KSampler")
|
||||||
cache_key: Any # Raw cache key (frozenset structure)
|
cache_key_hash: str # SHA256 hex digest for external storage key
|
||||||
cache_key_bytes: bytes # SHA256 hash for external storage key
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -71,9 +70,9 @@ class CacheProvider(ABC):
|
|||||||
"""
|
"""
|
||||||
Abstract base class for external cache providers.
|
Abstract base class for external cache providers.
|
||||||
|
|
||||||
Thread Safety:
|
Async Safety:
|
||||||
Providers may be called from multiple threads. Implementations
|
Provider methods are called from async context. Implementations
|
||||||
must be thread-safe.
|
can use async I/O (aiohttp, asyncpg, etc.) directly.
|
||||||
|
|
||||||
Error Handling:
|
Error Handling:
|
||||||
All methods are wrapped in try/except by the caller. Exceptions
|
All methods are wrapped in try/except by the caller. Exceptions
|
||||||
@ -81,12 +80,12 @@ class CacheProvider(ABC):
|
|||||||
|
|
||||||
Performance Guidelines:
|
Performance Guidelines:
|
||||||
- on_lookup: Should complete in <500ms (including network)
|
- on_lookup: Should complete in <500ms (including network)
|
||||||
- on_store: Can be async internally (fire-and-forget)
|
- on_store: Fire-and-forget via asyncio.create_task
|
||||||
- should_cache: Should be fast (<1ms), called frequently
|
- should_cache: Should be fast (<1ms), called frequently
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
"""
|
"""
|
||||||
Check external storage for cached result.
|
Check external storage for cached result.
|
||||||
|
|
||||||
@ -102,14 +101,14 @@ class CacheProvider(ABC):
|
|||||||
pass
|
pass
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
"""
|
"""
|
||||||
Store value to external cache.
|
Store value to external cache.
|
||||||
|
|
||||||
Called AFTER value is stored in local cache.
|
Called AFTER value is stored in local cache.
|
||||||
|
Dispatched as asyncio.create_task (fire-and-forget).
|
||||||
|
|
||||||
Important:
|
Important:
|
||||||
- Can be fire-and-forget (async internally)
|
|
||||||
- Should never block execution
|
- Should never block execution
|
||||||
- Handle serialization failures gracefully
|
- Handle serialization failures gracefully
|
||||||
"""
|
"""
|
||||||
@ -123,7 +122,7 @@ class CacheProvider(ABC):
|
|||||||
Return False to skip external caching for this node.
|
Return False to skip external caching for this node.
|
||||||
|
|
||||||
Implementations can filter based on context.class_type, value size,
|
Implementations can filter based on context.class_type, value size,
|
||||||
or any custom logic. Use estimate_value_size() to get value size.
|
or any custom logic. Use _estimate_value_size() to get value size.
|
||||||
|
|
||||||
Default: Returns True (cache everything).
|
Default: Returns True (cache everything).
|
||||||
"""
|
"""
|
||||||
@ -157,11 +156,11 @@ def register_cache_provider(provider: CacheProvider) -> None:
|
|||||||
global _providers_snapshot
|
global _providers_snapshot
|
||||||
with _providers_lock:
|
with _providers_lock:
|
||||||
if provider in _providers:
|
if provider in _providers:
|
||||||
logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||||
return
|
return
|
||||||
_providers.append(provider)
|
_providers.append(provider)
|
||||||
_providers_snapshot = None # Invalidate cache
|
_providers_snapshot = None # Invalidate cache
|
||||||
logger.info(f"Registered cache provider: {provider.__class__.__name__}")
|
_logger.info(f"Registered cache provider: {provider.__class__.__name__}")
|
||||||
|
|
||||||
|
|
||||||
def unregister_cache_provider(provider: CacheProvider) -> None:
|
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||||
@ -171,13 +170,13 @@ def unregister_cache_provider(provider: CacheProvider) -> None:
|
|||||||
try:
|
try:
|
||||||
_providers.remove(provider)
|
_providers.remove(provider)
|
||||||
_providers_snapshot = None
|
_providers_snapshot = None
|
||||||
logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
|
_logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||||
except ValueError:
|
except ValueError:
|
||||||
logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||||
|
|
||||||
|
|
||||||
def get_cache_providers() -> Tuple[CacheProvider, ...]:
|
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||||
"""Get registered providers (cached for performance)."""
|
"""Get registered providers (cached for performance). Internal."""
|
||||||
global _providers_snapshot
|
global _providers_snapshot
|
||||||
snapshot = _providers_snapshot
|
snapshot = _providers_snapshot
|
||||||
if snapshot is not None:
|
if snapshot is not None:
|
||||||
@ -189,13 +188,13 @@ def get_cache_providers() -> Tuple[CacheProvider, ...]:
|
|||||||
return _providers_snapshot
|
return _providers_snapshot
|
||||||
|
|
||||||
|
|
||||||
def has_cache_providers() -> bool:
|
def _has_cache_providers() -> bool:
|
||||||
"""Fast check if any providers registered (no lock)."""
|
"""Fast check if any providers registered (no lock). Internal."""
|
||||||
return bool(_providers)
|
return bool(_providers)
|
||||||
|
|
||||||
|
|
||||||
def clear_cache_providers() -> None:
|
def _clear_cache_providers() -> None:
|
||||||
"""Remove all providers. Useful for testing."""
|
"""Remove all providers. Useful for testing. Internal."""
|
||||||
global _providers_snapshot
|
global _providers_snapshot
|
||||||
with _providers_lock:
|
with _providers_lock:
|
||||||
_providers.clear()
|
_providers.clear()
|
||||||
@ -203,7 +202,7 @@ def clear_cache_providers() -> None:
|
|||||||
|
|
||||||
|
|
||||||
# ============================================================
|
# ============================================================
|
||||||
# Utilities
|
# Internal Utilities
|
||||||
# ============================================================
|
# ============================================================
|
||||||
|
|
||||||
def _canonicalize(obj: Any) -> Any:
|
def _canonicalize(obj: Any) -> Any:
|
||||||
@ -243,11 +242,11 @@ def _canonicalize(obj: Any) -> Any:
|
|||||||
return ("__repr__", repr(obj))
|
return ("__repr__", repr(obj))
|
||||||
|
|
||||||
|
|
||||||
def serialize_cache_key(cache_key: Any) -> bytes:
|
def _serialize_cache_key(cache_key: Any) -> str:
|
||||||
"""
|
"""
|
||||||
Serialize cache key to bytes for external storage.
|
Serialize cache key to a hex digest string for external storage.
|
||||||
|
|
||||||
Returns SHA256 hash suitable for Redis/database keys.
|
Returns SHA256 hex string suitable for Redis/database keys.
|
||||||
|
|
||||||
Note: Uses canonicalize + JSON serialization instead of pickle because
|
Note: Uses canonicalize + JSON serialization instead of pickle because
|
||||||
pickle is NOT deterministic across Python sessions due to hash randomization
|
pickle is NOT deterministic across Python sessions due to hash randomization
|
||||||
@ -257,18 +256,18 @@ def serialize_cache_key(cache_key: Any) -> bytes:
|
|||||||
try:
|
try:
|
||||||
canonical = _canonicalize(cache_key)
|
canonical = _canonicalize(cache_key)
|
||||||
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||||
return hashlib.sha256(json_str.encode('utf-8')).digest()
|
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to serialize cache key: {e}")
|
_logger.warning(f"Failed to serialize cache key: {e}")
|
||||||
# Fallback to pickle (non-deterministic but better than nothing)
|
# Fallback to pickle (non-deterministic but better than nothing)
|
||||||
try:
|
try:
|
||||||
serialized = pickle.dumps(cache_key, protocol=4)
|
serialized = pickle.dumps(cache_key, protocol=4)
|
||||||
return hashlib.sha256(serialized).digest()
|
return hashlib.sha256(serialized).hexdigest()
|
||||||
except Exception:
|
except Exception:
|
||||||
return hashlib.sha256(str(id(cache_key)).encode()).digest()
|
return hashlib.sha256(str(id(cache_key)).encode()).hexdigest()
|
||||||
|
|
||||||
|
|
||||||
def contains_nan(obj: Any) -> bool:
|
def _contains_nan(obj: Any) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if cache key contains NaN (indicates uncacheable node).
|
Check if cache key contains NaN (indicates uncacheable node).
|
||||||
|
|
||||||
@ -288,14 +287,14 @@ def contains_nan(obj: Any) -> bool:
|
|||||||
except (TypeError, ValueError):
|
except (TypeError, ValueError):
|
||||||
return False
|
return False
|
||||||
if isinstance(obj, (frozenset, tuple, list, set)):
|
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||||
return any(contains_nan(item) for item in obj)
|
return any(_contains_nan(item) for item in obj)
|
||||||
if isinstance(obj, dict):
|
if isinstance(obj, dict):
|
||||||
return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
|
return any(_contains_nan(k) or _contains_nan(v) for k, v in obj.items())
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
|
||||||
def estimate_value_size(value: CacheValue) -> int:
|
def _estimate_value_size(value: CacheValue) -> int:
|
||||||
"""Estimate serialized size in bytes. Useful for size-based filtering."""
|
"""Estimate serialized size in bytes. Useful for size-based filtering. Internal."""
|
||||||
try:
|
try:
|
||||||
import torch
|
import torch
|
||||||
except ImportError:
|
except ImportError:
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import bisect
|
import bisect
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
@ -200,15 +201,15 @@ class BasicCache:
|
|||||||
def poll(self, **kwargs):
|
def poll(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _set_immediate(self, node_id, value):
|
async def _set_immediate(self, node_id, value):
|
||||||
assert self.initialized
|
assert self.initialized
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
self.cache[cache_key] = value
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
# Notify external providers
|
# Notify external providers
|
||||||
self._notify_providers_store(node_id, cache_key, value)
|
await self._notify_providers_store(node_id, cache_key, value)
|
||||||
|
|
||||||
def _get_immediate(self, node_id):
|
async def _get_immediate(self, node_id):
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
return None
|
return None
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
@ -218,87 +219,88 @@ class BasicCache:
|
|||||||
return self.cache[cache_key]
|
return self.cache[cache_key]
|
||||||
|
|
||||||
# Check external providers on local miss
|
# Check external providers on local miss
|
||||||
external_result = self._check_providers_lookup(node_id, cache_key)
|
external_result = await self._check_providers_lookup(node_id, cache_key)
|
||||||
if external_result is not None:
|
if external_result is not None:
|
||||||
self.cache[cache_key] = external_result # Warm local cache
|
self.cache[cache_key] = external_result # Warm local cache
|
||||||
return external_result
|
return external_result
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def _notify_providers_store(self, node_id, cache_key, value):
|
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||||
"""Notify external providers of cache store."""
|
"""Notify external providers of cache store (fire-and-forget)."""
|
||||||
from comfy_execution.cache_provider import (
|
from comfy_execution.cache_provider import (
|
||||||
has_cache_providers, get_cache_providers,
|
_has_cache_providers, _get_cache_providers,
|
||||||
CacheContext, CacheValue,
|
CacheContext, CacheValue,
|
||||||
serialize_cache_key, contains_nan, logger
|
_serialize_cache_key, _contains_nan, _logger
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fast exit conditions
|
# Fast exit conditions
|
||||||
if self._is_subcache:
|
if self._is_subcache:
|
||||||
return
|
return
|
||||||
if not has_cache_providers():
|
if not _has_cache_providers():
|
||||||
return
|
return
|
||||||
if not self._is_external_cacheable_value(value):
|
if not self._is_external_cacheable_value(value):
|
||||||
return
|
return
|
||||||
if contains_nan(cache_key):
|
if _contains_nan(cache_key):
|
||||||
return
|
return
|
||||||
|
|
||||||
context = CacheContext(
|
context = self._build_context(node_id, cache_key)
|
||||||
prompt_id=self._current_prompt_id,
|
if context is None:
|
||||||
node_id=node_id,
|
return
|
||||||
class_type=self._get_class_type(node_id),
|
|
||||||
cache_key=cache_key,
|
|
||||||
cache_key_bytes=serialize_cache_key(cache_key)
|
|
||||||
)
|
|
||||||
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
|
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
|
||||||
|
|
||||||
for provider in get_cache_providers():
|
for provider in _get_cache_providers():
|
||||||
try:
|
try:
|
||||||
if provider.should_cache(context, cache_value):
|
if provider.should_cache(context, cache_value):
|
||||||
provider.on_store(context, cache_value)
|
asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||||
|
|
||||||
def _check_providers_lookup(self, node_id, cache_key):
|
@staticmethod
|
||||||
|
async def _safe_provider_store(provider, context, cache_value):
|
||||||
|
"""Wrapper for fire-and-forget provider.on_store with error handling."""
|
||||||
|
from comfy_execution.cache_provider import _logger
|
||||||
|
try:
|
||||||
|
await provider.on_store(context, cache_value)
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||||
|
|
||||||
|
async def _check_providers_lookup(self, node_id, cache_key):
|
||||||
"""Check external providers for cached result."""
|
"""Check external providers for cached result."""
|
||||||
from comfy_execution.cache_provider import (
|
from comfy_execution.cache_provider import (
|
||||||
has_cache_providers, get_cache_providers,
|
_has_cache_providers, _get_cache_providers,
|
||||||
CacheContext, CacheValue,
|
CacheContext, CacheValue,
|
||||||
serialize_cache_key, contains_nan, logger
|
_contains_nan, _logger
|
||||||
)
|
)
|
||||||
|
|
||||||
if self._is_subcache:
|
if self._is_subcache:
|
||||||
return None
|
return None
|
||||||
if not has_cache_providers():
|
if not _has_cache_providers():
|
||||||
return None
|
return None
|
||||||
if contains_nan(cache_key):
|
if _contains_nan(cache_key):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
context = CacheContext(
|
context = self._build_context(node_id, cache_key)
|
||||||
prompt_id=self._current_prompt_id,
|
if context is None:
|
||||||
node_id=node_id,
|
return None
|
||||||
class_type=self._get_class_type(node_id),
|
|
||||||
cache_key=cache_key,
|
|
||||||
cache_key_bytes=serialize_cache_key(cache_key)
|
|
||||||
)
|
|
||||||
|
|
||||||
for provider in get_cache_providers():
|
for provider in _get_cache_providers():
|
||||||
try:
|
try:
|
||||||
if not provider.should_cache(context):
|
if not provider.should_cache(context):
|
||||||
continue
|
continue
|
||||||
result = provider.on_lookup(context)
|
result = await provider.on_lookup(context)
|
||||||
if result is not None:
|
if result is not None:
|
||||||
if not isinstance(result, CacheValue):
|
if not isinstance(result, CacheValue):
|
||||||
logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||||
continue
|
continue
|
||||||
if not isinstance(result.outputs, (list, tuple)):
|
if not isinstance(result.outputs, (list, tuple)):
|
||||||
logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||||
continue
|
continue
|
||||||
# Import CacheEntry here to avoid circular import at module level
|
# Import CacheEntry here to avoid circular import at module level
|
||||||
from execution import CacheEntry
|
from execution import CacheEntry
|
||||||
return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
|
return CacheEntry(ui=result.ui or {}, outputs=list(result.outputs))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||||
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
@ -315,6 +317,16 @@ class BasicCache:
|
|||||||
except Exception:
|
except Exception:
|
||||||
return ''
|
return ''
|
||||||
|
|
||||||
|
def _build_context(self, node_id, cache_key):
|
||||||
|
"""Build CacheContext with hash. Returns None if hashing fails on NaN."""
|
||||||
|
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key
|
||||||
|
return CacheContext(
|
||||||
|
prompt_id=self._current_prompt_id,
|
||||||
|
node_id=node_id,
|
||||||
|
class_type=self._get_class_type(node_id),
|
||||||
|
cache_key_hash=_serialize_cache_key(cache_key)
|
||||||
|
)
|
||||||
|
|
||||||
async def _ensure_subcache(self, node_id, children_ids):
|
async def _ensure_subcache(self, node_id, children_ids):
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
subcache = self.subcaches.get(subcache_key, None)
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
@ -364,16 +376,16 @@ class HierarchicalCache(BasicCache):
|
|||||||
return None
|
return None
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
if cache is None:
|
if cache is None:
|
||||||
return None
|
return None
|
||||||
return cache._get_immediate(node_id)
|
return await cache._get_immediate(node_id)
|
||||||
|
|
||||||
def set(self, node_id, value):
|
async def set(self, node_id, value):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
assert cache is not None
|
assert cache is not None
|
||||||
cache._set_immediate(node_id, value)
|
await cache._set_immediate(node_id, value)
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
@ -394,10 +406,10 @@ class NullCache:
|
|||||||
def poll(self, **kwargs):
|
def poll(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set(self, node_id, value):
|
async def set(self, node_id, value):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
@ -429,18 +441,18 @@ class LRUCache(BasicCache):
|
|||||||
del self.children[key]
|
del self.children[key]
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
return self._get_immediate(node_id)
|
return await self._get_immediate(node_id)
|
||||||
|
|
||||||
def _mark_used(self, node_id):
|
def _mark_used(self, node_id):
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
if cache_key is not None:
|
if cache_key is not None:
|
||||||
self.used_generation[cache_key] = self.generation
|
self.used_generation[cache_key] = self.generation
|
||||||
|
|
||||||
def set(self, node_id, value):
|
async def set(self, node_id, value):
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
return self._set_immediate(node_id, value)
|
return await self._set_immediate(node_id, value)
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
# Just uses subcaches for tracking 'live' nodes
|
# Just uses subcaches for tracking 'live' nodes
|
||||||
@ -480,13 +492,13 @@ class RAMPressureCache(LRUCache):
|
|||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
def set(self, node_id, value):
|
async def set(self, node_id, value):
|
||||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
super().set(node_id, value)
|
await super().set(node_id, value)
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
return super().get(node_id)
|
return await super().get(node_id)
|
||||||
|
|
||||||
def poll(self, ram_headroom):
|
def poll(self, ram_headroom):
|
||||||
def _ram_gb():
|
def _ram_gb():
|
||||||
|
|||||||
10
execution.py
10
execution.py
@ -414,7 +414,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
cached = caches.outputs.get(unique_id)
|
cached = await caches.outputs.get(unique_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
cached_ui = cached.ui or {}
|
cached_ui = cached.ui or {}
|
||||||
@ -470,10 +470,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
obj = caches.objects.get(unique_id)
|
obj = await caches.objects.get(unique_id)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
obj = class_def()
|
obj = class_def()
|
||||||
caches.objects.set(unique_id, obj)
|
await caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
if issubclass(class_def, _ComfyNodeInternal):
|
if issubclass(class_def, _ComfyNodeInternal):
|
||||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||||
@ -575,7 +575,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
|
|
||||||
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
||||||
execution_list.cache_update(unique_id, cache_entry)
|
execution_list.cache_update(unique_id, cache_entry)
|
||||||
caches.outputs.set(unique_id, cache_entry)
|
await caches.outputs.set(unique_id, cache_entry)
|
||||||
|
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
logging.info("Processing interrupted")
|
logging.info("Processing interrupted")
|
||||||
@ -720,7 +720,7 @@ class PromptExecutor:
|
|||||||
|
|
||||||
cached_nodes = []
|
cached_nodes = []
|
||||||
for node_id in prompt:
|
for node_id in prompt:
|
||||||
if self.caches.outputs.get(node_id) is not None:
|
if await self.caches.outputs.get(node_id) is not None:
|
||||||
cached_nodes.append(node_id)
|
cached_nodes.append(node_id)
|
||||||
|
|
||||||
comfy.model_management.cleanup_models_gc()
|
comfy.model_management.cleanup_models_gc()
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user