mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 10:17:31 +08:00
Address review feedback from guill:
- Rename _contains_nan to _contains_self_unequal, use not (x == x) instead of math.isnan to catch any self-unequal value
- Remove Unhashable and repr() fallbacks from _canonicalize; raise ValueError for unknown types so _serialize_cache_key returns None and external caching is skipped (fail-closed)
- Update tests for renamed function and new fail-closed behavior

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
136 lines
4.6 KiB
Python
136 lines
4.6 KiB
Python
from typing import Any, Optional, Tuple, List
|
|
import hashlib
|
|
import json
|
|
import logging
|
|
import threading
|
|
|
|
# Public types — source of truth is comfy_api.latest._caching
|
|
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue
|
|
|
|
_logger = logging.getLogger(__name__)

# Serializes every mutation of the provider registry (register, unregister,
# clear).
_providers_lock = threading.Lock()

# Registered providers in registration order; only mutated while holding
# _providers_lock.
_providers: List[CacheProvider] = []

# Immutable tuple copy of _providers, rebuilt under the lock after every
# change so readers such as _get_cache_providers can use it without locking.
_providers_snapshot: Tuple[CacheProvider, ...] = tuple()
|
|
|
|
|
|
def register_cache_provider(provider: CacheProvider) -> None:
    """Add an external cache provider to the registry.

    Providers are called in registration order. A provider that is already
    registered (membership tested with ``in``) is not added a second time;
    a warning is logged instead.
    """
    global _providers_snapshot
    with _providers_lock:
        provider_name = provider.__class__.__name__
        if provider not in _providers:
            _providers.append(provider)
            # Publish a fresh immutable copy for lock-free readers.
            _providers_snapshot = tuple(_providers)
            _logger.info(f"Registered cache provider: {provider_name}")
        else:
            _logger.warning(f"Provider {provider_name} already registered")
|
|
|
|
|
|
def unregister_cache_provider(provider: CacheProvider) -> None:
    """Remove a previously registered external cache provider.

    Removal uses ``list.remove`` (equality semantics). After a successful
    removal, the snapshot read without locking by _get_cache_providers is
    rebuilt. Removing a provider that is not registered only logs a warning.
    """
    global _providers_snapshot
    with _providers_lock:
        try:
            _providers.remove(provider)
        except ValueError:
            # Only list.remove() is expected to raise here: the provider was
            # never registered (or has already been removed).
            _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
            return
        _providers_snapshot = tuple(_providers)
        _logger.info(f"Unregistered cache provider: {provider.__class__.__name__}")
|
|
|
|
|
|
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
    """Return the current snapshot of registered providers.

    The tuple is read without taking _providers_lock. The registry functions
    replace it wholesale on every change rather than mutating it.
    """
    snapshot = _providers_snapshot
    return snapshot
|
|
|
|
|
|
def _has_cache_providers() -> bool:
    """Report whether the provider snapshot is non-empty."""
    return len(_providers_snapshot) > 0
|
|
|
|
|
|
def _clear_cache_providers() -> None:
    """Drop every registered provider and reset the snapshot to empty."""
    global _providers_snapshot
    with _providers_lock:
        del _providers[:]
        _providers_snapshot = tuple()
|
|
|
|
|
|
# Tags that _canonicalize places first in a ("tag", payload) pair. A plain list
# whose first canonical element is one of these strings is wrapped in a
# "__list__" pair, so after JSON encoding it cannot be mistaken for a tagged value.
_CANONICAL_TAGS = frozenset({
    "__frozenset__",
    "__set__",
    "__tuple__",
    "__bytes__",
    "__dict__",
    "__list__",
})


def _canonical_sort_key(value: Any) -> str:
    # Session-independent ordering key for already-canonicalized values.
    return json.dumps(value, sort_keys=True)


def _canonicalize(obj: Any) -> Any:
    """Convert *obj* to a JSON-serializable form with deterministic ordering.

    Frozensets and sets have non-deterministic iteration order between Python
    sessions, so their canonicalized members are sorted by JSON encoding.
    Tuples, sets, frozensets and bytes are tagged so they stay distinct from
    lists and strings once encoded as JSON.

    Dicts whose keys are all strings become plain JSON objects. A dict with
    any non-string key becomes a tagged, sorted list of ``[key, value]``
    pairs with canonicalized keys. This keeps, e.g., ``{1: x}`` and
    ``{"1": x}`` distinct.

    Raises:
        ValueError: for non-cacheable types (anything not handled below), so
            that _serialize_cache_key returns None and external caching is
            skipped.
    """
    if isinstance(obj, frozenset):
        return ("__frozenset__", sorted(
            [_canonicalize(item) for item in obj], key=_canonical_sort_key
        ))
    elif isinstance(obj, set):
        return ("__set__", sorted(
            [_canonicalize(item) for item in obj], key=_canonical_sort_key
        ))
    elif isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize(item) for item in obj])
    elif isinstance(obj, list):
        items = [_canonicalize(item) for item in obj]
        # Escape lists that would otherwise encode exactly like a tagged
        # value, e.g. ["__tuple__", [1, 2]] vs. the tuple (1, 2).
        if items and isinstance(items[0], str) and items[0] in _CANONICAL_TAGS:
            return ("__list__", items)
        return items
    elif isinstance(obj, dict):
        if all(isinstance(k, str) for k in obj):
            return {str(k): _canonicalize(v) for k, v in sorted(obj.items())}
        # Non-string keys are canonicalized rather than str()-ed. str() made
        # {1: x} and {"1": x} collide and could embed object reprs; unknown
        # key types now raise ValueError instead (fail-closed).
        pairs = [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()]
        return ("__dict__", sorted(pairs, key=_canonical_sort_key))
    elif isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    elif isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    else:
        raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
|
|
|
|
|
|
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
    """Hash *cache_key* into a deterministic SHA-256 hex digest.

    The canonical form is encoded as compact, key-sorted JSON before hashing.
    JSON is used rather than pickle because pickle output is not stable across
    sessions. Any failure is logged as a warning and yields None.
    """
    try:
        payload = json.dumps(
            _canonicalize(cache_key),
            sort_keys=True,
            separators=(',', ':'),
        ).encode('utf-8')
        digest = hashlib.sha256(payload)
    except Exception as err:
        _logger.warning(f"Failed to serialize cache key: {err}")
        return None
    return digest.hexdigest()
|
|
|
|
|
|
def _contains_self_unequal(obj: Any) -> bool:
    """Return True if *obj*, or anything nested in it, is not equal to itself.

    The local cache matches by ``==``, so a value where ``not (x == x)`` (NaN,
    etc.) never hits locally, while its serialized form would still match
    externally. Such keys should be skipped. Any exception raised while
    probing a value is treated as self-unequal, so the check fails closed.
    """
    try:
        if not (obj == obj):
            return True
    except Exception:
        return True
    if isinstance(obj, (frozenset, tuple, list, set)):
        # Container == uses an identity shortcut per item, so [nan] == [nan]
        # is True; the items have to be inspected individually.
        return any(_contains_self_unequal(item) for item in obj)
    if isinstance(obj, dict):
        return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
    # Objects exposing a ``value`` attribute are checked through it. Read it
    # once: hasattr() would re-raise any non-AttributeError from a property.
    try:
        inner = obj.value
    except AttributeError:
        return False
    except Exception:
        # A failing ``value`` accessor: fail closed, like the == probe above.
        return True
    if inner is obj:
        # A self-referencing ``value`` would recurse forever; obj itself has
        # already been checked above.
        return False
    return _contains_self_unequal(inner)
|
|
|
|
|
|
def _estimate_value_size(value: CacheValue) -> int:
    """Roughly estimate the bytes held by torch tensors in ``value.outputs``.

    Walks nested dicts (values only), lists and tuples, adding
    ``numel() * element_size()`` for each ``torch.Tensor`` encountered. All
    other objects count as zero. Returns 0 when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        return 0

    def _walk(node) -> int:
        # Size of the tensors reachable from a single node.
        if isinstance(node, torch.Tensor):
            return node.numel() * node.element_size()
        if isinstance(node, dict):
            children = node.values()
        elif isinstance(node, (list, tuple)):
            children = node
        else:
            return 0
        subtotal = 0
        for child in children:
            subtotal += _walk(child)
        return subtotal

    grand_total = 0
    for output in value.outputs:
        grand_total += _walk(output)
    return grand_total
|