This commit is contained in:
xmarre 2026-03-14 05:36:12 +00:00 committed by GitHub
commit 072fc014c7
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 96 additions and 15 deletions

View File

@ -1,7 +1,6 @@
import asyncio
import bisect
import gc
import itertools
import psutil
import time
import torch
@ -51,17 +50,94 @@ class Unhashable:
def __init__(self):
self.value = float("NaN")
def to_hashable(obj):
# So that we don't infinitely recurse since frozenset and tuples
# are Sequences.
if isinstance(obj, (int, float, str, bool, bytes, type(None))):
return obj
elif isinstance(obj, Mapping):
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items())])
elif isinstance(obj, Sequence):
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj]))
def _sanitized_sort_key(obj, depth=0, max_depth=32):
if depth >= max_depth:
return ("MAX_DEPTH",)
obj_type = type(obj)
if obj_type is Unhashable:
return ("UNHASHABLE",)
elif obj_type in (int, float, str, bool, bytes, type(None)):
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
elif obj_type is dict:
items = [
(
_sanitized_sort_key(k, depth + 1, max_depth),
_sanitized_sort_key(v, depth + 1, max_depth),
)
for k, v in obj.items()
]
items.sort()
return ("dict", tuple(items))
elif obj_type is list:
return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
elif obj_type is tuple:
return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
elif obj_type is set:
return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
elif obj_type is frozenset:
return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
else:
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
def _sanitize_signature_input(obj, depth=0, max_depth=32):
if depth >= max_depth:
return Unhashable()
obj_type = type(obj)
if obj_type in (int, float, str, bool, bytes, type(None)):
return obj
elif obj_type is dict:
sanitized_items = [
(
_sanitize_signature_input(key, depth + 1, max_depth),
_sanitize_signature_input(value, depth + 1, max_depth),
)
for key, value in obj.items()
]
sanitized_items.sort(
key=lambda kv: (
_sanitized_sort_key(kv[0], depth + 1, max_depth),
_sanitized_sort_key(kv[1], depth + 1, max_depth),
)
)
return {key: value for key, value in sanitized_items}
elif obj_type is list:
return [_sanitize_signature_input(item, depth + 1, max_depth) for item in obj]
elif obj_type is tuple:
return tuple(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
elif obj_type is set:
return {_sanitize_signature_input(item, depth + 1, max_depth) for item in obj}
elif obj_type is frozenset:
return frozenset(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
else:
# Execution-cache signatures should be built from prompt-safe values.
# If a custom node injects a runtime object here, mark it unhashable so
# the node won't reuse stale cache entries across runs, but do not walk
# the foreign object and risk crashing on custom container semantics.
return Unhashable()
def to_hashable(obj):
# Restrict recursion to plain built-in containers. Some custom nodes insert
# runtime objects into prompt inputs for dynamic graph paths; walking those
# objects as generic Mappings / Sequences is unsafe and can destabilize the
# cache signature builder.
obj_type = type(obj)
if obj_type in (int, float, str, bool, bytes, type(None)):
return obj
elif obj_type is dict:
return ("dict", frozenset((to_hashable(k), to_hashable(v)) for k, v in obj.items()))
elif obj_type is list:
return ("list", tuple(to_hashable(i) for i in obj))
elif obj_type is tuple:
return ("tuple", tuple(to_hashable(i) for i in obj))
elif obj_type is set:
return ("set", frozenset(to_hashable(i) for i in obj))
elif obj_type is frozenset:
return ("frozenset", frozenset(to_hashable(i) for i in obj))
else:
# TODO - Support other objects like tensors?
return Unhashable()
class CacheKeySetID(CacheKeySet):
@ -123,7 +199,7 @@ class CacheKeySetInputSignature(CacheKeySet):
ancestor_index = ancestor_order_mapping[ancestor_id]
signature.append((key,("ANCESTOR", ancestor_index, ancestor_socket)))
else:
signature.append((key, inputs[key]))
signature.append((key, _sanitize_signature_input(inputs[key])))
return signature
# This function returns a list of all ancestors of the given node. The order of the list is

View File

@ -1,11 +1,16 @@
def is_link(obj):
if not isinstance(obj, list):
# Prompt links produced by the frontend / GraphBuilder are plain Python
# lists in the form [node_id, output_index]. Some custom-node paths can
# inject foreign runtime objects into prompt inputs during on-prompt graph
# rewriting or subgraph construction. Be strict here so cache signature
# building never tries to treat list-like proxy objects as links.
if type(obj) is not list:
return False
if len(obj) != 2:
return False
if not isinstance(obj[0], str):
if type(obj[0]) is not str:
return False
if not isinstance(obj[1], int) and not isinstance(obj[1], float):
if type(obj[1]) is not int:
return False
return True