mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-17 07:05:12 +08:00
Sanitize execution cache inputs
This commit is contained in:
parent
e1f10ca093
commit
7d76a4447e
@ -51,17 +51,48 @@ class Unhashable:
|
||||
def __init__(self):
|
||||
self.value = float("NaN")
|
||||
|
||||
def to_hashable(obj):
|
||||
# So that we don't infinitely recurse since frozenset and tuples
|
||||
# are Sequences.
|
||||
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
|
||||
|
||||
def _sanitize_signature_input(obj, depth=0, max_depth=32):
|
||||
if depth >= max_depth:
|
||||
return Unhashable()
|
||||
|
||||
obj_type = type(obj)
|
||||
if obj_type in (int, float, str, bool, bytes, type(None)):
|
||||
return obj
|
||||
elif isinstance(obj, Mapping):
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
|
||||
elif isinstance(obj, Sequence):
|
||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||
elif obj_type is dict:
|
||||
sanitized = []
|
||||
for key in sorted(obj.keys(), key=lambda x: (type(x).__module__, type(x).__qualname__, repr(x))):
|
||||
sanitized.append((_sanitize_signature_input(key, depth + 1, max_depth),
|
||||
_sanitize_signature_input(obj[key], depth + 1, max_depth)))
|
||||
return tuple(sanitized)
|
||||
elif obj_type in (list, tuple):
|
||||
return tuple(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
|
||||
elif obj_type in (set, frozenset):
|
||||
sanitized_items = [_sanitize_signature_input(item, depth + 1, max_depth) for item in obj]
|
||||
sanitized_items.sort(key=repr)
|
||||
return tuple(sanitized_items)
|
||||
else:
|
||||
# Execution-cache signatures should be built from prompt-safe values.
|
||||
# If a custom node injects a runtime object here, mark it unhashable so
|
||||
# the node won't reuse stale cache entries across runs, but do not walk
|
||||
# the foreign object and risk crashing on custom container semantics.
|
||||
return Unhashable()
|
||||
|
||||
def to_hashable(obj):
|
||||
# Restrict recursion to plain built-in containers. Some custom nodes insert
|
||||
# runtime objects into prompt inputs for dynamic graph paths; walking those
|
||||
# objects as generic Mappings / Sequences is unsafe and can destabilize the
|
||||
# cache signature builder.
|
||||
obj_type = type(obj)
|
||||
if obj_type in (int, float, str, bool, bytes, type(None)):
|
||||
return obj
|
||||
elif obj_type is dict:
|
||||
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items(), key=lambda kv: repr(kv[0]))])
|
||||
elif obj_type in (list, tuple):
|
||||
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
|
||||
elif obj_type in (set, frozenset):
|
||||
return frozenset([to_hashable(i) for i in obj])
|
||||
else:
|
||||
# TODO - Support other objects like tensors?
|
||||
return Unhashable()
|
||||
|
||||
class CacheKeySetID(CacheKeySet):
|
||||
@ -123,7 +154,7 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
ancestor_index = ancestor_order_mapping[ancestor_id]
|
||||
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
|
||||
else:
|
||||
signature.append((key, inputs[key]))
|
||||
signature.append((key, _sanitize_signature_input(inputs[key])))
|
||||
return signature
|
||||
|
||||
# This function returns a list of all ancestors of the given node. The order of the list is
|
||||
|
||||
@ -1,11 +1,16 @@
|
||||
def is_link(obj):
|
||||
if not isinstance(obj, list):
|
||||
# Prompt links produced by the frontend / GraphBuilder are plain Python
|
||||
# lists in the form [node_id, output_index]. Some custom-node paths can
|
||||
# inject foreign runtime objects into prompt inputs during on-prompt graph
|
||||
# rewriting or subgraph construction. Be strict here so cache signature
|
||||
# building never tries to treat list-like proxy objects as links.
|
||||
if type(obj) is not list:
|
||||
return False
|
||||
if len(obj) != 2:
|
||||
return False
|
||||
if not isinstance(obj[0], str):
|
||||
if type(obj[0]) is not str:
|
||||
return False
|
||||
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
|
||||
if type(obj[1]) not in (int, float):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user