mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-21 09:03:37 +08:00
Merge 2bea0ee5d7 into 16cd8d8a8f
This commit is contained in:
commit
9c8eb96254
@ -1,7 +1,6 @@
|
|||||||
import asyncio
|
import asyncio
|
||||||
import bisect
|
import bisect
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
|
||||||
import psutil
|
import psutil
|
||||||
import time
|
import time
|
||||||
import torch
|
import torch
|
||||||
@ -17,6 +16,7 @@ NODE_CLASS_CONTAINS_UNIQUE_ID: Dict[str, bool] = {}
|
|||||||
|
|
||||||
|
|
||||||
def include_unique_id_in_input(class_type: str) -> bool:
|
def include_unique_id_in_input(class_type: str) -> bool:
|
||||||
|
"""Return whether a node class includes UNIQUE_ID among its hidden inputs."""
|
||||||
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
if class_type in NODE_CLASS_CONTAINS_UNIQUE_ID:
|
||||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
@ -24,52 +24,181 @@ def include_unique_id_in_input(class_type: str) -> bool:
|
|||||||
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
return NODE_CLASS_CONTAINS_UNIQUE_ID[class_type]
|
||||||
|
|
||||||
class CacheKeySet(ABC):
|
class CacheKeySet(ABC):
|
||||||
|
"""Base helper for building and storing cache keys for prompt nodes."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""Initialize cache-key storage for a dynamic prompt execution pass."""
|
||||||
self.keys = {}
|
self.keys = {}
|
||||||
self.subcache_keys = {}
|
self.subcache_keys = {}
|
||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
|
"""Populate cache keys for the provided node ids."""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
def all_node_ids(self):
|
def all_node_ids(self):
|
||||||
|
"""Return the set of node ids currently tracked by this key set."""
|
||||||
return set(self.keys.keys())
|
return set(self.keys.keys())
|
||||||
|
|
||||||
def get_used_keys(self):
|
def get_used_keys(self):
|
||||||
|
"""Return the computed cache keys currently in use."""
|
||||||
return self.keys.values()
|
return self.keys.values()
|
||||||
|
|
||||||
def get_used_subcache_keys(self):
|
def get_used_subcache_keys(self):
|
||||||
|
"""Return the computed subcache keys currently in use."""
|
||||||
return self.subcache_keys.values()
|
return self.subcache_keys.values()
|
||||||
|
|
||||||
def get_data_key(self, node_id):
|
def get_data_key(self, node_id):
|
||||||
|
"""Return the cache key for a node, if present."""
|
||||||
return self.keys.get(node_id, None)
|
return self.keys.get(node_id, None)
|
||||||
|
|
||||||
def get_subcache_key(self, node_id):
|
def get_subcache_key(self, node_id):
|
||||||
|
"""Return the subcache key for a node, if present."""
|
||||||
return self.subcache_keys.get(node_id, None)
|
return self.subcache_keys.get(node_id, None)
|
||||||
|
|
||||||
class Unhashable:
|
class Unhashable:
|
||||||
def __init__(self):
|
"""Hashable identity sentinel for values that cannot be represented safely in cache keys."""
|
||||||
self.value = float("NaN")
|
pass
|
||||||
|
|
||||||
def to_hashable(obj):
|
|
||||||
# So that we don't infinitely recurse since frozenset and tuples
|
def _sanitized_sort_key(obj, depth=0, max_depth=32):
|
||||||
# are Sequences.
|
"""Return a deterministic ordering key for sanitized built-in container content."""
|
||||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
if depth >= max_depth:
|
||||||
return obj
|
return ("MAX_DEPTH",)
|
||||||
elif isinstance(obj, Mapping):
|
|
||||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
obj_type = type(obj)
|
||||||
elif isinstance(obj, Sequence):
|
if obj_type is Unhashable:
|
||||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
return ("UNHASHABLE",)
|
||||||
|
elif obj_type in (int, float, str, bool, bytes, type(None)):
|
||||||
|
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
|
||||||
|
elif obj_type is dict:
|
||||||
|
items = [
|
||||||
|
(
|
||||||
|
_sanitized_sort_key(k, depth + 1, max_depth),
|
||||||
|
_sanitized_sort_key(v, depth + 1, max_depth),
|
||||||
|
)
|
||||||
|
for k, v in obj.items()
|
||||||
|
]
|
||||||
|
items.sort()
|
||||||
|
return ("dict", tuple(items))
|
||||||
|
elif obj_type is list:
|
||||||
|
return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
|
||||||
|
elif obj_type is tuple:
|
||||||
|
return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
|
||||||
|
elif obj_type is set:
|
||||||
|
return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
|
||||||
|
elif obj_type is frozenset:
|
||||||
|
return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
|
||||||
|
else:
|
||||||
|
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
|
||||||
|
|
||||||
|
|
||||||
|
def _sanitize_signature_input(obj, depth=0, max_depth=32, seen=None):
|
||||||
|
"""Normalize signature inputs to safe built-in containers.
|
||||||
|
|
||||||
|
Preserves built-in container type, replaces opaque runtime values with
|
||||||
|
Unhashable(), and stops safely on cycles or excessive depth.
|
||||||
|
"""
|
||||||
|
if depth >= max_depth:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
if seen is None:
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
obj_type = type(obj)
|
||||||
|
if obj_type in (dict, list, tuple, set, frozenset):
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in seen:
|
||||||
|
return Unhashable()
|
||||||
|
next_seen = seen | {obj_id}
|
||||||
|
|
||||||
|
if obj_type in (int, float, str, bool, bytes, type(None)):
|
||||||
|
return obj
|
||||||
|
elif obj_type is dict:
|
||||||
|
sanitized_items = [
|
||||||
|
(
|
||||||
|
_sanitize_signature_input(key, depth + 1, max_depth, next_seen),
|
||||||
|
_sanitize_signature_input(value, depth + 1, max_depth, next_seen),
|
||||||
|
)
|
||||||
|
for key, value in obj.items()
|
||||||
|
]
|
||||||
|
sanitized_items.sort(
|
||||||
|
key=lambda kv: (
|
||||||
|
_sanitized_sort_key(kv[0], depth + 1, max_depth),
|
||||||
|
_sanitized_sort_key(kv[1], depth + 1, max_depth),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return {key: value for key, value in sanitized_items}
|
||||||
|
elif obj_type is list:
|
||||||
|
return [_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj]
|
||||||
|
elif obj_type is tuple:
|
||||||
|
return tuple(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
|
||||||
|
elif obj_type is set:
|
||||||
|
return {_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj}
|
||||||
|
elif obj_type is frozenset:
|
||||||
|
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth, next_seen) for item in obj)
|
||||||
|
else:
|
||||||
|
# Execution-cache signatures should be built from prompt-safe values.
|
||||||
|
# If a custom node injects a runtime object here, mark it unhashable so
|
||||||
|
# the node won't reuse stale cache entries across runs, but do not walk
|
||||||
|
# the foreign object and risk crashing on custom container semantics.
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
def to_hashable(obj, depth=0, max_depth=32, seen=None):
|
||||||
|
"""Convert sanitized prompt inputs into a stable hashable representation.
|
||||||
|
|
||||||
|
Preserves built-in container type and stops safely on cycles or excessive depth.
|
||||||
|
"""
|
||||||
|
if depth >= max_depth:
|
||||||
|
return Unhashable()
|
||||||
|
|
||||||
|
if seen is None:
|
||||||
|
seen = set()
|
||||||
|
|
||||||
|
# Restrict recursion to plain built-in containers. Some custom nodes insert
|
||||||
|
# runtime objects into prompt inputs for dynamic graph paths; walking those
|
||||||
|
# objects as generic Mappings / Sequences is unsafe and can destabilize the
|
||||||
|
# cache signature builder.
|
||||||
|
obj_type = type(obj)
|
||||||
|
if obj_type in (int, float, str, bool, bytes, type(None)):
|
||||||
|
return obj
|
||||||
|
|
||||||
|
if obj_type in (dict, list, tuple, set, frozenset):
|
||||||
|
obj_id = id(obj)
|
||||||
|
if obj_id in seen:
|
||||||
|
return Unhashable()
|
||||||
|
seen = seen | {obj_id}
|
||||||
|
|
||||||
|
if obj_type is dict:
|
||||||
|
return (
|
||||||
|
"dict",
|
||||||
|
frozenset(
|
||||||
|
(
|
||||||
|
to_hashable(k, depth + 1, max_depth, seen),
|
||||||
|
to_hashable(v, depth + 1, max_depth, seen),
|
||||||
|
)
|
||||||
|
for k, v in obj.items()
|
||||||
|
),
|
||||||
|
)
|
||||||
|
elif obj_type is list:
|
||||||
|
return ("list", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||||
|
elif obj_type is tuple:
|
||||||
|
return ("tuple", tuple(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||||
|
elif obj_type is set:
|
||||||
|
return ("set", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||||
|
elif obj_type is frozenset:
|
||||||
|
return ("frozenset", frozenset(to_hashable(i, depth + 1, max_depth, seen) for i in obj))
|
||||||
else:
|
else:
|
||||||
# TODO - Support other objects like tensors?
|
|
||||||
return Unhashable()
|
return Unhashable()
|
||||||
|
|
||||||
class CacheKeySetID(CacheKeySet):
|
class CacheKeySetID(CacheKeySet):
|
||||||
|
"""Cache-key strategy that keys nodes by node id and class type."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""Initialize identity-based cache keys for the supplied dynamic prompt."""
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
|
|
||||||
async def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
|
"""Populate identity-based keys for nodes that exist in the dynamic prompt."""
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
@ -80,15 +209,19 @@ class CacheKeySetID(CacheKeySet):
|
|||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
class CacheKeySetInputSignature(CacheKeySet):
|
class CacheKeySetInputSignature(CacheKeySet):
|
||||||
|
"""Cache-key strategy that hashes a node's immediate inputs plus ancestor references."""
|
||||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||||
|
"""Initialize input-signature-based cache keys for the supplied dynamic prompt."""
|
||||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.is_changed_cache = is_changed_cache
|
self.is_changed_cache = is_changed_cache
|
||||||
|
|
||||||
def include_node_id_in_input(self) -> bool:
|
def include_node_id_in_input(self) -> bool:
|
||||||
|
"""Return whether node ids should be included in computed input signatures."""
|
||||||
return False
|
return False
|
||||||
|
|
||||||
async def add_keys(self, node_ids):
|
async def add_keys(self, node_ids):
|
||||||
|
"""Populate input-signature-based keys for nodes in the dynamic prompt."""
|
||||||
for node_id in node_ids:
|
for node_id in node_ids:
|
||||||
if node_id in self.keys:
|
if node_id in self.keys:
|
||||||
continue
|
continue
|
||||||
@ -99,6 +232,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
self.subcache_keys[node_id] = (node_id, node["class_type"])
|
||||||
|
|
||||||
async def get_node_signature(self, dynprompt, node_id):
|
async def get_node_signature(self, dynprompt, node_id):
|
||||||
|
"""Build the full cache signature for a node and its ordered ancestors."""
|
||||||
signature = []
|
signature = []
|
||||||
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
ancestors, order_mapping = self.get_ordered_ancestry(dynprompt, node_id)
|
||||||
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
signature.append(await self.get_immediate_node_signature(dynprompt, node_id, order_mapping))
|
||||||
@ -107,6 +241,10 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
return to_hashable(signature)
|
return to_hashable(signature)
|
||||||
|
|
||||||
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
async def get_immediate_node_signature(self, dynprompt, node_id, ancestor_order_mapping):
|
||||||
|
"""Build the cache-signature fragment for a node's immediate inputs.
|
||||||
|
|
||||||
|
Link inputs are reduced to ancestor references, while raw values are sanitized first.
|
||||||
|
"""
|
||||||
if not dynprompt.has_node(node_id):
|
if not dynprompt.has_node(node_id):
|
||||||
# This node doesn't exist -- we can't cache it.
|
# This node doesn't exist -- we can't cache it.
|
||||||
return [float("NaN")]
|
return [float("NaN")]
|
||||||
@ -123,18 +261,20 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||||
else:
|
else:
|
||||||
signature.append((key, inputs[key]))
|
signature.append((key, _sanitize_signature_input(inputs[key])))
|
||||||
return signature
|
return signature
|
||||||
|
|
||||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||||
# deterministic based on which specific inputs the ancestor is connected by.
|
# deterministic based on which specific inputs the ancestor is connected by.
|
||||||
def get_ordered_ancestry(self, dynprompt, node_id):
|
def get_ordered_ancestry(self, dynprompt, node_id):
|
||||||
|
"""Return ancestors in deterministic traversal order and their index mapping."""
|
||||||
ancestors = []
|
ancestors = []
|
||||||
order_mapping = {}
|
order_mapping = {}
|
||||||
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
self.get_ordered_ancestry_internal(dynprompt, node_id, ancestors, order_mapping)
|
||||||
return ancestors, order_mapping
|
return ancestors, order_mapping
|
||||||
|
|
||||||
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
def get_ordered_ancestry_internal(self, dynprompt, node_id, ancestors, order_mapping):
|
||||||
|
"""Recursively collect ancestors in input order without revisiting prior nodes."""
|
||||||
if not dynprompt.has_node(node_id):
|
if not dynprompt.has_node(node_id):
|
||||||
return
|
return
|
||||||
inputs = dynprompt.get_node(node_id)["inputs"]
|
inputs = dynprompt.get_node(node_id)["inputs"]
|
||||||
|
|||||||
@ -1,11 +1,17 @@
|
|||||||
def is_link(obj):
|
def is_link(obj):
|
||||||
if not isinstance(obj, list):
|
"""Return whether obj is a plain prompt link of the form [node_id, output_index]."""
|
||||||
|
# Prompt links produced by the frontend / GraphBuilder are plain Python
|
||||||
|
# lists in the form [node_id, output_index]. Some custom-node paths can
|
||||||
|
# inject foreign runtime objects into prompt inputs during on-prompt graph
|
||||||
|
# rewriting or subgraph construction. Be strict here so cache signature
|
||||||
|
# building never tries to treat list-like proxy objects as links.
|
||||||
|
if type(obj) is not list:
|
||||||
return False
|
return False
|
||||||
if len(obj) != 2:
|
if len(obj) != 2:
|
||||||
return False
|
return False
|
||||||
if not isinstance(obj[0], str):
|
if type(obj[0]) is not str:
|
||||||
return False
|
return False
|
||||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
if type(obj[1]) is not int:
|
||||||
return False
|
return False
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user