mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-31 08:40:19 +08:00
feat: Add CacheProvider API for external distributed caching
Introduces a public API for external cache providers, enabling distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods). New files: - comfy_execution/cache_provider.py: CacheProvider ABC, CacheContext/CacheValue dataclasses, thread-safe provider registry, serialization utilities Modified files: - comfy_execution/caching.py: Add provider hooks to BasicCache (_notify_providers_store, _check_providers_lookup), subcache exclusion, prompt ID propagation - execution.py: Add prompt lifecycle hooks (on_prompt_start/on_prompt_end) to PromptExecutor, set _current_prompt_id on caches Key features: - Local-first caching (check local before external for performance) - NaN detection to prevent incorrect external cache hits - Subcache exclusion (ephemeral subgraph results not cached externally) - Thread-safe provider snapshot caching - Graceful error handling (provider errors logged, never break execution) 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
parent
ec0a832acb
commit
6540aa0400
267
comfy_execution/cache_provider.py
Normal file
267
comfy_execution/cache_provider.py
Normal file
@ -0,0 +1,267 @@
|
|||||||
|
"""
|
||||||
|
External Cache Provider API for distributed caching.
|
||||||
|
|
||||||
|
This module provides a public API for external cache providers, enabling
|
||||||
|
distributed caching across multiple ComfyUI instances (e.g., Kubernetes pods).
|
||||||
|
|
||||||
|
Example usage:
|
||||||
|
from comfy_execution.cache_provider import (
|
||||||
|
CacheProvider, CacheContext, CacheValue, register_cache_provider
|
||||||
|
)
|
||||||
|
|
||||||
|
class MyRedisProvider(CacheProvider):
|
||||||
|
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
|
# Check Redis/GCS for cached result
|
||||||
|
...
|
||||||
|
|
||||||
|
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
|
# Store to Redis/GCS (can be async internally)
|
||||||
|
...
|
||||||
|
|
||||||
|
register_cache_provider(MyRedisProvider())
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Any, Optional, Tuple, List
|
||||||
|
from dataclasses import dataclass
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
import hashlib
|
||||||
|
import pickle
|
||||||
|
import math
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Data Classes
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheContext:
|
||||||
|
"""Context passed to provider methods."""
|
||||||
|
prompt_id: str # Current prompt execution ID
|
||||||
|
node_id: str # Node being cached
|
||||||
|
class_type: str # Node class type (e.g., "KSampler")
|
||||||
|
cache_key: Any # Raw cache key (frozenset structure)
|
||||||
|
cache_key_bytes: bytes # SHA256 hash for external storage key
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheValue:
|
||||||
|
"""
|
||||||
|
Value stored/retrieved from external cache.
|
||||||
|
|
||||||
|
Note: UI data is intentionally excluded - it contains pod-local
|
||||||
|
file paths that aren't portable across instances.
|
||||||
|
"""
|
||||||
|
outputs: list # The tensor/value outputs
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Provider Interface
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
class CacheProvider(ABC):
|
||||||
|
"""
|
||||||
|
Abstract base class for external cache providers.
|
||||||
|
|
||||||
|
Thread Safety:
|
||||||
|
Providers may be called from multiple threads. Implementations
|
||||||
|
must be thread-safe.
|
||||||
|
|
||||||
|
Error Handling:
|
||||||
|
All methods are wrapped in try/except by the caller. Exceptions
|
||||||
|
are logged but never propagate to break execution.
|
||||||
|
|
||||||
|
Performance Guidelines:
|
||||||
|
- on_lookup: Should complete in <500ms (including network)
|
||||||
|
- on_store: Can be async internally (fire-and-forget)
|
||||||
|
- should_cache: Should be fast (<1ms), called frequently
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
|
"""
|
||||||
|
Check external storage for cached result.
|
||||||
|
|
||||||
|
Called AFTER local cache miss (local-first for performance).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
CacheValue if found externally, None otherwise.
|
||||||
|
|
||||||
|
Important:
|
||||||
|
- Return None on any error (don't raise)
|
||||||
|
- Validate data integrity before returning
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
|
"""
|
||||||
|
Store value to external cache.
|
||||||
|
|
||||||
|
Called AFTER value is stored in local cache.
|
||||||
|
|
||||||
|
Important:
|
||||||
|
- Can be fire-and-forget (async internally)
|
||||||
|
- Should never block execution
|
||||||
|
- Handle serialization failures gracefully
|
||||||
|
"""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||||
|
"""
|
||||||
|
Filter which nodes should be externally cached.
|
||||||
|
|
||||||
|
Called before on_lookup (value=None) and on_store (value provided).
|
||||||
|
Return False to skip external caching for this node.
|
||||||
|
|
||||||
|
Common filters:
|
||||||
|
- By class_type: Only expensive nodes (KSampler, VAEDecode)
|
||||||
|
- By size: Skip small values (< 1MB)
|
||||||
|
|
||||||
|
Default: Returns True (cache everything).
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_prompt_start(self, prompt_id: str) -> None:
|
||||||
|
"""Called when prompt execution begins. Optional."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_prompt_end(self, prompt_id: str) -> None:
|
||||||
|
"""Called when prompt execution ends. Optional."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Provider Registry
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
_providers: List[CacheProvider] = []
|
||||||
|
_providers_lock = threading.Lock()
|
||||||
|
_providers_snapshot: Optional[Tuple[CacheProvider, ...]] = None
|
||||||
|
|
||||||
|
|
||||||
|
def register_cache_provider(provider: CacheProvider) -> None:
|
||||||
|
"""
|
||||||
|
Register an external cache provider.
|
||||||
|
|
||||||
|
Providers are called in registration order. First provider to return
|
||||||
|
a result from on_lookup wins.
|
||||||
|
"""
|
||||||
|
global _providers_snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
if provider in _providers:
|
||||||
|
logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||||
|
return
|
||||||
|
_providers.append(provider)
|
||||||
|
_providers_snapshot = None # Invalidate cache
|
||||||
|
logger.info(f"Registered cache provider: {provider.__class__.__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||||
|
"""Remove a previously registered provider."""
|
||||||
|
global _providers_snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
try:
|
||||||
|
_providers.remove(provider)
|
||||||
|
_providers_snapshot = None
|
||||||
|
logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||||
|
except ValueError:
|
||||||
|
logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||||
|
|
||||||
|
|
||||||
|
def get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||||
|
"""Get registered providers (cached for performance)."""
|
||||||
|
global _providers_snapshot
|
||||||
|
snapshot = _providers_snapshot
|
||||||
|
if snapshot is not None:
|
||||||
|
return snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
if _providers_snapshot is not None:
|
||||||
|
return _providers_snapshot
|
||||||
|
_providers_snapshot = tuple(_providers)
|
||||||
|
return _providers_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def has_cache_providers() -> bool:
|
||||||
|
"""Fast check if any providers registered (no lock)."""
|
||||||
|
return bool(_providers)
|
||||||
|
|
||||||
|
|
||||||
|
def clear_cache_providers() -> None:
|
||||||
|
"""Remove all providers. Useful for testing."""
|
||||||
|
global _providers_snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
_providers.clear()
|
||||||
|
_providers_snapshot = None
|
||||||
|
|
||||||
|
|
||||||
|
# ============================================================
|
||||||
|
# Utilities
|
||||||
|
# ============================================================
|
||||||
|
|
||||||
|
def serialize_cache_key(cache_key: Any) -> bytes:
|
||||||
|
"""
|
||||||
|
Serialize cache key to bytes for external storage.
|
||||||
|
|
||||||
|
Returns SHA256 hash suitable for Redis/database keys.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
serialized = pickle.dumps(cache_key, protocol=4)
|
||||||
|
return hashlib.sha256(serialized).digest()
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Failed to serialize cache key: {e}")
|
||||||
|
return hashlib.sha256(str(id(cache_key)).encode()).digest()
|
||||||
|
|
||||||
|
|
||||||
|
def contains_nan(obj: Any) -> bool:
|
||||||
|
"""
|
||||||
|
Check if cache key contains NaN (indicates uncacheable node).
|
||||||
|
|
||||||
|
NaN != NaN in Python, so local cache never hits. But serialized
|
||||||
|
NaN would match, causing incorrect external hits. Must skip these.
|
||||||
|
"""
|
||||||
|
if isinstance(obj, float):
|
||||||
|
try:
|
||||||
|
return math.isnan(obj)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
if hasattr(obj, 'value'): # Unhashable class
|
||||||
|
val = getattr(obj, 'value', None)
|
||||||
|
if isinstance(val, float):
|
||||||
|
try:
|
||||||
|
return math.isnan(val)
|
||||||
|
except (TypeError, ValueError):
|
||||||
|
return False
|
||||||
|
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||||
|
return any(contains_nan(item) for item in obj)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return any(contains_nan(k) or contains_nan(v) for k, v in obj.items())
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def estimate_value_size(value: CacheValue) -> int:
|
||||||
|
"""Estimate serialized size in bytes. Useful for size-based filtering."""
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except ImportError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
def estimate(obj):
|
||||||
|
nonlocal total
|
||||||
|
if isinstance(obj, torch.Tensor):
|
||||||
|
total += obj.numel() * obj.element_size()
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
for v in obj.values():
|
||||||
|
estimate(v)
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
for item in obj:
|
||||||
|
estimate(item)
|
||||||
|
|
||||||
|
for output in value.outputs:
|
||||||
|
estimate(output)
|
||||||
|
return total
|
||||||
@ -155,6 +155,10 @@ class BasicCache:
|
|||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.subcaches = {}
|
self.subcaches = {}
|
||||||
|
|
||||||
|
# External cache provider support
|
||||||
|
self._is_subcache = False
|
||||||
|
self._current_prompt_id = ''
|
||||||
|
|
||||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||||
@ -201,20 +205,123 @@ class BasicCache:
|
|||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
self.cache[cache_key] = value
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
|
# Notify external providers
|
||||||
|
self._notify_providers_store(node_id, cache_key, value)
|
||||||
|
|
||||||
def _get_immediate(self, node_id):
|
def _get_immediate(self, node_id):
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
return None
|
return None
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
|
||||||
|
# Check local cache first (fast path)
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
return self.cache[cache_key]
|
return self.cache[cache_key]
|
||||||
else:
|
|
||||||
|
# Check external providers on local miss
|
||||||
|
external_result = self._check_providers_lookup(node_id, cache_key)
|
||||||
|
if external_result is not None:
|
||||||
|
self.cache[cache_key] = external_result # Warm local cache
|
||||||
|
return external_result
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _notify_providers_store(self, node_id, cache_key, value):
|
||||||
|
"""Notify external providers of cache store."""
|
||||||
|
from comfy_execution.cache_provider import (
|
||||||
|
has_cache_providers, get_cache_providers,
|
||||||
|
CacheContext, CacheValue,
|
||||||
|
serialize_cache_key, contains_nan, logger
|
||||||
|
)
|
||||||
|
|
||||||
|
# Fast exit conditions
|
||||||
|
if self._is_subcache:
|
||||||
|
return
|
||||||
|
if not has_cache_providers():
|
||||||
|
return
|
||||||
|
if not self._is_cacheable_value(value):
|
||||||
|
return
|
||||||
|
if contains_nan(cache_key):
|
||||||
|
return
|
||||||
|
|
||||||
|
context = CacheContext(
|
||||||
|
prompt_id=self._current_prompt_id,
|
||||||
|
node_id=node_id,
|
||||||
|
class_type=self._get_class_type(node_id),
|
||||||
|
cache_key=cache_key,
|
||||||
|
cache_key_bytes=serialize_cache_key(cache_key)
|
||||||
|
)
|
||||||
|
cache_value = CacheValue(outputs=value.outputs)
|
||||||
|
|
||||||
|
for provider in get_cache_providers():
|
||||||
|
try:
|
||||||
|
if provider.should_cache(context, cache_value):
|
||||||
|
provider.on_store(context, cache_value)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||||
|
|
||||||
|
def _check_providers_lookup(self, node_id, cache_key):
|
||||||
|
"""Check external providers for cached result."""
|
||||||
|
from comfy_execution.cache_provider import (
|
||||||
|
has_cache_providers, get_cache_providers,
|
||||||
|
CacheContext, CacheValue,
|
||||||
|
serialize_cache_key, contains_nan, logger
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._is_subcache:
|
||||||
return None
|
return None
|
||||||
|
if not has_cache_providers():
|
||||||
|
return None
|
||||||
|
if contains_nan(cache_key):
|
||||||
|
return None
|
||||||
|
|
||||||
|
context = CacheContext(
|
||||||
|
prompt_id=self._current_prompt_id,
|
||||||
|
node_id=node_id,
|
||||||
|
class_type=self._get_class_type(node_id),
|
||||||
|
cache_key=cache_key,
|
||||||
|
cache_key_bytes=serialize_cache_key(cache_key)
|
||||||
|
)
|
||||||
|
|
||||||
|
for provider in get_cache_providers():
|
||||||
|
try:
|
||||||
|
if not provider.should_cache(context):
|
||||||
|
continue
|
||||||
|
result = provider.on_lookup(context)
|
||||||
|
if result is not None:
|
||||||
|
if not isinstance(result, CacheValue):
|
||||||
|
logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||||
|
continue
|
||||||
|
if not isinstance(result.outputs, (list, tuple)):
|
||||||
|
logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||||
|
continue
|
||||||
|
# Import CacheEntry here to avoid circular import at module level
|
||||||
|
from execution import CacheEntry
|
||||||
|
return CacheEntry(ui={}, outputs=list(result.outputs))
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _is_cacheable_value(self, value):
|
||||||
|
"""Check if value is a CacheEntry (not objects cache)."""
|
||||||
|
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||||
|
|
||||||
|
def _get_class_type(self, node_id):
|
||||||
|
"""Get class_type for a node."""
|
||||||
|
if not self.initialized or not self.dynprompt:
|
||||||
|
return ''
|
||||||
|
try:
|
||||||
|
return self.dynprompt.get_node(node_id).get('class_type', '')
|
||||||
|
except Exception:
|
||||||
|
return ''
|
||||||
|
|
||||||
async def _ensure_subcache(self, node_id, children_ids):
|
async def _ensure_subcache(self, node_id, children_ids):
|
||||||
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
subcache_key = self.cache_key_set.get_subcache_key(node_id)
|
||||||
subcache = self.subcaches.get(subcache_key, None)
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
if subcache is None:
|
if subcache is None:
|
||||||
subcache = BasicCache(self.key_class)
|
subcache = BasicCache(self.key_class)
|
||||||
|
subcache._is_subcache = True # Mark as subcache - excludes from external caching
|
||||||
|
subcache._current_prompt_id = self._current_prompt_id # Propagate prompt ID
|
||||||
self.subcaches[subcache_key] = subcache
|
self.subcaches[subcache_key] = subcache
|
||||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||||
return subcache
|
return subcache
|
||||||
|
|||||||
137
execution.py
137
execution.py
@ -669,6 +669,22 @@ class PromptExecutor:
|
|||||||
}
|
}
|
||||||
self.add_message("execution_error", mes, broadcast=False)
|
self.add_message("execution_error", mes, broadcast=False)
|
||||||
|
|
||||||
|
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
|
||||||
|
"""Notify external cache providers of prompt lifecycle events."""
|
||||||
|
from comfy_execution.cache_provider import has_cache_providers, get_cache_providers, logger
|
||||||
|
|
||||||
|
if not has_cache_providers():
|
||||||
|
return
|
||||||
|
|
||||||
|
for provider in get_cache_providers():
|
||||||
|
try:
|
||||||
|
if event == "start":
|
||||||
|
provider.on_prompt_start(prompt_id)
|
||||||
|
elif event == "end":
|
||||||
|
provider.on_prompt_end(prompt_id)
|
||||||
|
except Exception as e:
|
||||||
|
logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
|
||||||
|
|
||||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||||
|
|
||||||
@ -685,66 +701,77 @@ class PromptExecutor:
|
|||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
with torch.inference_mode():
|
# Set prompt ID on caches for external provider integration
|
||||||
dynamic_prompt = DynamicPrompt(prompt)
|
for cache in self.caches.all:
|
||||||
reset_progress_state(prompt_id, dynamic_prompt)
|
cache._current_prompt_id = prompt_id
|
||||||
add_progress_handler(WebUIProgressHandler(self.server))
|
|
||||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
|
||||||
for cache in self.caches.all:
|
|
||||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
|
||||||
cache.clean_unused()
|
|
||||||
|
|
||||||
cached_nodes = []
|
# Notify external cache providers of prompt start
|
||||||
for node_id in prompt:
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
if self.caches.outputs.get(node_id) is not None:
|
|
||||||
cached_nodes.append(node_id)
|
|
||||||
|
|
||||||
comfy.model_management.cleanup_models_gc()
|
try:
|
||||||
self.add_message("execution_cached",
|
with torch.inference_mode():
|
||||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
broadcast=False)
|
reset_progress_state(prompt_id, dynamic_prompt)
|
||||||
pending_subgraph_results = {}
|
add_progress_handler(WebUIProgressHandler(self.server))
|
||||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||||
ui_node_outputs = {}
|
for cache in self.caches.all:
|
||||||
executed = set()
|
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
cache.clean_unused()
|
||||||
current_outputs = self.caches.outputs.all_node_ids()
|
|
||||||
for node_id in list(execute_outputs):
|
|
||||||
execution_list.add_node(node_id)
|
|
||||||
|
|
||||||
while not execution_list.is_empty():
|
cached_nodes = []
|
||||||
node_id, error, ex = await execution_list.stage_node_execution()
|
for node_id in prompt:
|
||||||
if error is not None:
|
if self.caches.outputs.get(node_id) is not None:
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
cached_nodes.append(node_id)
|
||||||
break
|
|
||||||
|
|
||||||
assert node_id is not None, "Node ID should not be None at this point"
|
comfy.model_management.cleanup_models_gc()
|
||||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
self.add_message("execution_cached",
|
||||||
self.success = result != ExecutionResult.FAILURE
|
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||||
if result == ExecutionResult.FAILURE:
|
broadcast=False)
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
pending_subgraph_results = {}
|
||||||
break
|
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||||
elif result == ExecutionResult.PENDING:
|
ui_node_outputs = {}
|
||||||
execution_list.unstage_node_execution()
|
executed = set()
|
||||||
else: # result == ExecutionResult.SUCCESS:
|
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||||
execution_list.complete_node_execution()
|
current_outputs = self.caches.outputs.all_node_ids()
|
||||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
for node_id in list(execute_outputs):
|
||||||
else:
|
execution_list.add_node(node_id)
|
||||||
# Only execute when the while-loop ends without break
|
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
|
||||||
|
|
||||||
ui_outputs = {}
|
while not execution_list.is_empty():
|
||||||
meta_outputs = {}
|
node_id, error, ex = await execution_list.stage_node_execution()
|
||||||
for node_id, ui_info in ui_node_outputs.items():
|
if error is not None:
|
||||||
ui_outputs[node_id] = ui_info["output"]
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
meta_outputs[node_id] = ui_info["meta"]
|
break
|
||||||
self.history_result = {
|
|
||||||
"outputs": ui_outputs,
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
"meta": meta_outputs,
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||||
}
|
self.success = result != ExecutionResult.FAILURE
|
||||||
self.server.last_node_id = None
|
if result == ExecutionResult.FAILURE:
|
||||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
comfy.model_management.unload_all_models()
|
break
|
||||||
|
elif result == ExecutionResult.PENDING:
|
||||||
|
execution_list.unstage_node_execution()
|
||||||
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
|
execution_list.complete_node_execution()
|
||||||
|
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||||
|
else:
|
||||||
|
# Only execute when the while-loop ends without break
|
||||||
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
|
ui_outputs = {}
|
||||||
|
meta_outputs = {}
|
||||||
|
for node_id, ui_info in ui_node_outputs.items():
|
||||||
|
ui_outputs[node_id] = ui_info["output"]
|
||||||
|
meta_outputs[node_id] = ui_info["meta"]
|
||||||
|
self.history_result = {
|
||||||
|
"outputs": ui_outputs,
|
||||||
|
"meta": meta_outputs,
|
||||||
|
}
|
||||||
|
self.server.last_node_id = None
|
||||||
|
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||||
|
comfy.model_management.unload_all_models()
|
||||||
|
finally:
|
||||||
|
# Notify external cache providers of prompt end
|
||||||
|
self._notify_prompt_lifecycle("end", prompt_id)
|
||||||
|
|
||||||
|
|
||||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user