diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 1ca1edcc0..2169dda9a 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -154,42 +154,57 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti active.add(obj_id) try: if obj_type is dict: - sort_memo = {} - sanitized_items = [ - ( - _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget), - _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget), - ) - for key, value in obj.items() - ] - ordered_items = [ - ( + try: + sort_memo = {} + sanitized_items = [ ( - _sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo), - _sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo), - ), - (key, value), - ) - for key, value in sanitized_items - ] - ordered_items.sort(key=lambda item: item[0]) + _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget), + _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget), + ) + for key, value in obj.items() + ] + ordered_items = [ + ( + ( + _sanitized_sort_key(key, depth + 1, max_depth, memo=sort_memo), + _sanitized_sort_key(value, depth + 1, max_depth, memo=sort_memo), + ), + (key, value), + ) + for key, value in sanitized_items + ] + ordered_items.sort(key=lambda item: item[0]) - result = Unhashable() - for index in range(1, len(ordered_items)): - previous_sort_key, previous_item = ordered_items[index - 1] - current_sort_key, current_item = ordered_items[index] - if previous_sort_key == current_sort_key and previous_item != current_item: - break - else: - result = {key: value for _, (key, value) in ordered_items} + result = Unhashable() + for index in range(1, len(ordered_items)): + previous_sort_key, previous_item = ordered_items[index - 1] + current_sort_key, current_item = ordered_items[index] + if previous_sort_key == current_sort_key and previous_item != current_item: + break + else: + result = {key: value for _, (key, value) in ordered_items} + except RuntimeError: + result = Unhashable() elif obj_type is list: - result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj] + try: + result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj] + except RuntimeError: + result = Unhashable() elif obj_type is tuple: - result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + try: + result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + except RuntimeError: + result = Unhashable() elif obj_type is set: - result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj} + try: + result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj} + except RuntimeError: + result = Unhashable() else: - result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + try: + result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + except RuntimeError: + result = Unhashable() finally: active.discard(obj_id) @@ -226,11 +241,14 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): def resolve_unordered_values(current, container_tag): """Resolve a set-like container or fail closed if ordering is ambiguous.""" - ordered_items = [ - (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item)) - for item in current - ] - ordered_items.sort(key=lambda item: item[0]) + try: + ordered_items = [ + (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item)) + for item in current + ] + ordered_items.sort(key=lambda item: item[0]) + except RuntimeError: + return Unhashable() for index in range(1, len(ordered_items)): previous_key, previous_value = ordered_items[index - 1] @@ -256,19 +274,22 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): if expanded: active.discard(current_id) - if current_type is dict: - memo[current_id] = ( - "dict", - tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()), - ) - elif current_type is list: - memo[current_id] = ("list", tuple(resolve_value(item) for item in current)) - elif current_type is tuple: - memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current)) - elif current_type is set: - memo[current_id] = resolve_unordered_values(current, "set") - else: - memo[current_id] = resolve_unordered_values(current, "frozenset") + try: + if current_type is dict: + memo[current_id] = ( + "dict", + tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()), + ) + elif current_type is list: + memo[current_id] = ("list", tuple(resolve_value(item) for item in current)) + elif current_type is tuple: + memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current)) + elif current_type is set: + memo[current_id] = resolve_unordered_values(current, "set") + else: + memo[current_id] = resolve_unordered_values(current, "frozenset") + except RuntimeError: + memo[current_id] = Unhashable() continue if current_id in active: @@ -282,12 +303,22 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): active.add(current_id) stack.append((current, True)) if current_type is dict: - items = list(current.items()) + try: + items = list(current.items()) + except RuntimeError: + memo[current_id] = Unhashable() + active.discard(current_id) + continue for key, value in reversed(items): stack.append((value, False)) stack.append((key, False)) else: - items = list(current) + try: + items = list(current) + except RuntimeError: + memo[current_id] = Unhashable() + active.discard(current_id) + continue for item in reversed(items): stack.append((item, False)) diff --git a/tests-unit/execution_test/caching_test.py b/tests-unit/execution_test/caching_test.py index 2f088722e..c9892304a 100644 --- a/tests-unit/execution_test/caching_test.py +++ b/tests-unit/execution_test/caching_test.py @@ -105,6 +105,35 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m assert sanitized[0][1]["value"] == 2 +@pytest.mark.parametrize( + "container_factory", + [ + lambda marker: [marker], + lambda marker: (marker,), + lambda marker: {marker}, + lambda marker: frozenset({marker}), + lambda marker: {marker: "value"}, + ], +) +def test_sanitize_signature_input_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade sanitization to Unhashable.""" + caching, _ = caching_module + original = caching._sanitize_signature_input + marker = object() + + def raising_sanitize(obj, *args, **kwargs): + """Raise a traversal RuntimeError for the marker value and delegate otherwise.""" + if obj is marker: + raise RuntimeError("container changed during iteration") + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitize_signature_input", raising_sanitize) + + sanitized = original(container_factory(marker)) + + assert isinstance(sanitized, caching.Unhashable) + + def test_to_hashable_handles_shared_builtin_substructures(caching_module): """Repeated sanitized content should hash stably for shared substructures.""" caching, _ = caching_module @@ -118,6 +147,28 @@ def test_to_hashable_handles_shared_builtin_substructures(caching_module): assert hashable[1][0][0] == "list" +@pytest.mark.parametrize( + "container_factory", + [ + set, + frozenset, + ], +) +def test_to_hashable_fails_closed_on_runtimeerror(caching_module, monkeypatch, container_factory): + """Traversal RuntimeError should degrade unordered hash conversion to Unhashable.""" + caching, _ = caching_module + + def raising_sort_key(obj, *args, **kwargs): + """Raise a traversal RuntimeError while unordered values are canonicalized.""" + raise RuntimeError("container changed during iteration") + + monkeypatch.setattr(caching, "_sanitized_sort_key", raising_sort_key) + + hashable = caching.to_hashable(container_factory({"value"})) + + assert isinstance(hashable, caching.Unhashable) + + def test_sanitize_signature_input_fails_closed_for_ambiguous_dict_ordering(caching_module): """Ambiguous dict sort ties should fail closed instead of depending on input order.""" caching, _ = caching_module