Keep container types in sanitizer

This commit is contained in:
xmarre 2026-03-14 06:36:06 +01:00
parent 0c1bfad0df
commit 2adde5a0e1
2 changed files with 62 additions and 17 deletions

View File

@ -1,7 +1,6 @@
import asyncio import asyncio
import bisect import bisect
import gc import gc
import itertools
import psutil import psutil
import time import time
import torch import torch
@ -52,6 +51,37 @@ class Unhashable:
self.value = float("NaN") self.value = float("NaN")
def _sanitized_sort_key(obj, depth=0, max_depth=32):
if depth >= max_depth:
return ("MAX_DEPTH",)
obj_type = type(obj)
if obj_type is Unhashable:
return ("UNHASHABLE",)
elif obj_type in (int, float, str, bool, bytes, type(None)):
return (obj_type.__module__, obj_type.__qualname__, repr(obj))
elif obj_type is dict:
items = [
(
_sanitized_sort_key(k, depth + 1, max_depth),
_sanitized_sort_key(v, depth + 1, max_depth),
)
for k, v in obj.items()
]
items.sort()
return ("dict", tuple(items))
elif obj_type is list:
return ("list", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
elif obj_type is tuple:
return ("tuple", tuple(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj))
elif obj_type is set:
return ("set", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
elif obj_type is frozenset:
return ("frozenset", tuple(sorted(_sanitized_sort_key(i, depth + 1, max_depth) for i in obj)))
else:
return (obj_type.__module__, obj_type.__qualname__, "OPAQUE")
def _sanitize_signature_input(obj, depth=0, max_depth=32): def _sanitize_signature_input(obj, depth=0, max_depth=32):
if depth >= max_depth: if depth >= max_depth:
return Unhashable() return Unhashable()
@ -60,17 +90,28 @@ def _sanitize_signature_input(obj, depth=0, max_depth=32):
if obj_type in (int, float, str, bool, bytes, type(None)): if obj_type in (int, float, str, bool, bytes, type(None)):
return obj return obj
elif obj_type is dict: elif obj_type is dict:
sanitized = [] sanitized_items = [
for key in sorted(obj.keys(), key=lambda x: (type(x).__module__, type(x).__qualname__, repr(x))): (
sanitized.append((_sanitize_signature_input(key, depth + 1, max_depth), _sanitize_signature_input(key, depth + 1, max_depth),
_sanitize_signature_input(obj[key], depth + 1, max_depth))) _sanitize_signature_input(value, depth + 1, max_depth),
return tuple(sanitized) )
elif obj_type in (list, tuple): for key, value in obj.items()
]
sanitized_items.sort(
key=lambda kv: (
_sanitized_sort_key(kv[0], depth + 1, max_depth),
_sanitized_sort_key(kv[1], depth + 1, max_depth),
)
)
return {key: value for key, value in sanitized_items}
elif obj_type is list:
return [_sanitize_signature_input(item, depth + 1, max_depth) for item in obj]
elif obj_type is tuple:
return tuple(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj) return tuple(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
elif obj_type in (set, frozenset): elif obj_type is set:
sanitized_items = [_sanitize_signature_input(item, depth + 1, max_depth) for item in obj] return {_sanitize_signature_input(item, depth + 1, max_depth) for item in obj}
sanitized_items.sort(key=repr) elif obj_type is frozenset:
return tuple(sanitized_items) return frozenset(_sanitize_signature_input(item, depth + 1, max_depth) for item in obj)
else: else:
# Execution-cache signatures should be built from prompt-safe values. # Execution-cache signatures should be built from prompt-safe values.
# If a custom node injects a runtime object here, mark it unhashable so # If a custom node injects a runtime object here, mark it unhashable so
@ -87,11 +128,15 @@ def to_hashable(obj):
if obj_type in (int, float, str, bool, bytes, type(None)): if obj_type in (int, float, str, bool, bytes, type(None)):
return obj return obj
elif obj_type is dict: elif obj_type is dict:
return frozenset([(to_hashable(k), to_hashable(v)) for k, v in sorted(obj.items(), key=lambda kv: repr(kv[0]))]) return ("dict", frozenset((to_hashable(k), to_hashable(v)) for k, v in obj.items()))
elif obj_type in (list, tuple): elif obj_type is list:
return frozenset(zip(itertools.count(), [to_hashable(i) for i in obj])) return ("list", tuple(to_hashable(i) for i in obj))
elif obj_type in (set, frozenset): elif obj_type is tuple:
return frozenset([to_hashable(i) for i in obj]) return ("tuple", tuple(to_hashable(i) for i in obj))
elif obj_type is set:
return ("set", frozenset(to_hashable(i) for i in obj))
elif obj_type is frozenset:
return ("frozenset", frozenset(to_hashable(i) for i in obj))
else: else:
return Unhashable() return Unhashable()

View File

@ -10,7 +10,7 @@ def is_link(obj):
return False return False
if type(obj[0]) is not str: if type(obj[0]) is not str:
return False return False
if type(obj[1]) not in (int, float): if type(obj[1]) is not int:
return False return False
return True return True