Add docstrings and harden signature

This commit is contained in:
xmarre 2026-03-15 02:55:39 +01:00
parent aceaa5e579
commit 117afbc1d7
2 changed files with 132 additions and 50 deletions

View File

@ -154,42 +154,57 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti
active.add(obj_id) active.add(obj_id)
try: try:
if obj_type is dict: if obj_type is dict:
sort_memo = {} try:
sanitized_items = [ sort_memo = {}
( sanitized_items = [
_sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
_sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
)
for key, value in obj.items()
]
ordered_items = [
(
( (
_sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo), _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget),
_sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo), _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget),
), )
(key, value), for key, value in obj.items()
) ]
for key, value in sanitized_items ordered_items = [
] (
ordered_items.sort(key=lambda item: item[0]) (
_sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo),
_sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo),
),
(key, value),
)
for key, value in sanitized_items
]
ordered_items.sort(key=lambda item: item[0])
result = Unhashable() result = Unhashable()
for index in range(1, len(ordered_items)): for index in range(1, len(ordered_items)):
previous_sort_key, previous_item = ordered_items[index - 1] previous_sort_key, previous_item = ordered_items[index - 1]
current_sort_key, current_item = ordered_items[index] current_sort_key, current_item = ordered_items[index]
if previous_sort_key == current_sort_key and previous_item != current_item: if previous_sort_key == current_sort_key and previous_item != current_item:
break break
else: else:
result = {key: value for _, (key, value) in ordered_items} result = {key: value for _, (key, value) in ordered_items}
except RuntimeError:
result = Unhashable()
elif obj_type is list: elif obj_type is list:
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj] try:
result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj]
except RuntimeError:
result = Unhashable()
elif obj_type is tuple: elif obj_type is tuple:
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) try:
result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj)
except RuntimeError:
result = Unhashable()
elif obj_type is set: elif obj_type is set:
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj} try:
result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj}
except RuntimeError:
result = Unhashable()
else: else:
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) try:
result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj)
except RuntimeError:
result = Unhashable()
finally: finally:
active.discard(obj_id) active.discard(obj_id)
@ -226,11 +241,14 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
def resolve_unordered_values(current, container_tag): def resolve_unordered_values(current, container_tag):
"""Resolve a set-like container or fail closed if ordering is ambiguous.""" """Resolve a set-like container or fail closed if ordering is ambiguous."""
ordered_items = [ try:
(_sanitized_sort_key(item, memo=sort_memo), resolve_value(item)) ordered_items = [
for item in current (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item))
] for item in current
ordered_items.sort(key=lambda item: item[0]) ]
ordered_items.sort(key=lambda item: item[0])
except RuntimeError:
return Unhashable()
for index in range(1, len(ordered_items)): for index in range(1, len(ordered_items)):
previous_key, previous_value = ordered_items[index - 1] previous_key, previous_value = ordered_items[index - 1]
@ -256,19 +274,22 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
if expanded: if expanded:
active.discard(current_id) active.discard(current_id)
if current_type is dict: try:
memo[current_id] = ( if current_type is dict:
"dict", memo[current_id] = (
tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()), "dict",
) tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()),
elif current_type is list: )
memo[current_id] = ("list", tuple(resolve_value(item) for item in current)) elif current_type is list:
elif current_type is tuple: memo[current_id] = ("list", tuple(resolve_value(item) for item in current))
memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current)) elif current_type is tuple:
elif current_type is set: memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current))
memo[current_id] = resolve_unordered_values(current, "set") elif current_type is set:
else: memo[current_id] = resolve_unordered_values(current, "set")
memo[current_id] = resolve_unordered_values(current, "frozenset") else:
memo[current_id] = resolve_unordered_values(current, "frozenset")
except RuntimeError:
memo[current_id] = Unhashable()
continue continue
if current_id in active: if current_id in active:
@ -282,12 +303,22 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS):
active.add(current_id) active.add(current_id)
stack.append((current, True)) stack.append((current, True))
if current_type is dict: if current_type is dict:
items = list(current.items()) try:
items = list(current.items())
except RuntimeError:
memo[current_id] = Unhashable()
active.discard(current_id)
continue
for key, value in reversed(items): for key, value in reversed(items):
stack.append((value, False)) stack.append((value, False))
stack.append((key, False)) stack.append((key, False))
else: else:
items = list(current) try:
items = list(current)
except RuntimeError:
memo[current_id] = Unhashable()
active.discard(current_id)
continue
for item in reversed(items): for item in reversed(items):
stack.append((item, False)) stack.append((item, False))

View File

@ -105,6 +105,35 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m
assert sanitized[0][1]["value"] == 2 assert sanitized[0][1]["value"] == 2
@pytest.mark.parametrize(
"container_factory",
[
lambda marker: [marker],
lambda marker: (marker,),
lambda marker: {marker},
lambda marker: frozenset({marker}),
lambda marker: {marker: "value"},
],
)
def test_sanitize_signature_input_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
"""Traversal RuntimeError should degrade sanitization to Unhashable."""
caching, _ = caching_module
original = caching._sanitize_signature_input
marker = object()
def raising_sanitize(obj, *args, **kwargs):
"""Raise a traversal RuntimeError for the marker value and delegate otherwise."""
if obj is marker:
raise RuntimeError("container changed during iteration")
return original(obj, *args, **kwargs)
monkeypatch.setattr(caching, "_sanitize_signature_input", raising_sanitize)
sanitized = original(container_factory(marker))
assert isinstance(sanitized, caching.Unhashable)
def test_to_hashable_handles_shared_builtin_substructures(caching_module): def test_to_hashable_handles_shared_builtin_substructures(caching_module):
"""Repeated sanitized content should hash stably for shared substructures.""" """Repeated sanitized content should hash stably for shared substructures."""
caching, _ = caching_module caching, _ = caching_module
@ -118,6 +147,28 @@ def test_to_hashable_handles_shared_builtin_substructures(caching_module):
assert hashable[1][0][0] == "list" assert hashable[1][0][0] == "list"
@pytest.mark.parametrize(
"container_factory",
[
set,
frozenset,
],
)
def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory):
"""Traversal RuntimeError should degrade unordered hash conversion to Unhashable."""
caching, _ = caching_module
def raising_sort_key(obj, *args, **kwargs):
"""Raise a traversal RuntimeError while unordered values are canonicalized."""
raise RuntimeError("container changed during iteration")
monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key)
hashable = caching.to_hashable(container_factory({"value"}))
assert isinstance(hashable, caching.Unhashable)
def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module): def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module):
"""Ambiguous dict sort ties should fail closed instead of depending on input order.""" """Ambiguous dict sort ties should fail closed instead of depending on input order."""
caching, _ = caching_module caching, _ = caching_module