diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index 2169dda9a..08f3f436b 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -155,13 +155,14 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti try: if obj_type is dict: try: + items = list(obj.items()) sort_memo = {} sanitized_items = [ ( _sanitize_signature_input(key, depth + 1, max_depth, active, memo, budget), _sanitize_signature_input(value, depth + 1, max_depth, active, memo, budget), ) - for key, value in obj.items() + for key, value in items ] ordered_items = [ ( @@ -187,22 +188,26 @@ def _sanitize_signature_input(obj, depth=0, max_depth=_MAX_SIGNATURE_DEPTH, acti result = Unhashable() elif obj_type is list: try: - result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj] + items = list(obj) + result = [_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items] except RuntimeError: result = Unhashable() elif obj_type is tuple: try: - result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + items = list(obj) + result = tuple(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items) except RuntimeError: result = Unhashable() elif obj_type is set: try: - result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj} + items = list(obj) + result = {_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items} except RuntimeError: result = Unhashable() else: try: - result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in obj) + items = list(obj) + result = frozenset(_sanitize_signature_input(item, depth + 1, max_depth, active, memo, budget) for item in items) except RuntimeError: result = Unhashable() finally: diff --git a/tests-unit/execution_test/caching_test.py b/tests-unit/execution_test/caching_test.py index c9892304a..6313faed1 100644 --- a/tests-unit/execution_test/caching_test.py +++ b/tests-unit/execution_test/caching_test.py @@ -105,6 +105,48 @@ def test_sanitize_signature_input_handles_shared_builtin_substructures(caching_m assert sanitized[0][1]["value"] == 2 +def test_sanitize_signature_input_snapshots_list_before_recursing(caching_module, monkeypatch): + """List sanitization should read a point-in-time snapshot before recursive descent.""" + caching, _ = caching_module + original = caching._sanitize_signature_input + marker = object() + values = [marker, 2] + + def mutating_sanitize(obj, *args, **kwargs): + """Mutate the live list during recursion to verify snapshot-based traversal.""" + if obj is marker: + values[1] = 3 + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitize_signature_input", mutating_sanitize) + + sanitized = original(values) + + assert isinstance(sanitized, list) + assert sanitized[1] == 2 + + +def test_sanitize_signature_input_snapshots_dict_before_recursing(caching_module, monkeypatch): + """Dict sanitization should read a point-in-time snapshot before recursive descent.""" + caching, _ = caching_module + original = caching._sanitize_signature_input + marker = object() + values = {"first": marker, "second": 2} + + def mutating_sanitize(obj, *args, **kwargs): + """Mutate the live dict during recursion to verify snapshot-based traversal.""" + if obj is marker: + values["second"] = 3 + return original(obj, *args, **kwargs) + + monkeypatch.setattr(caching, "_sanitize_signature_input", mutating_sanitize) + + sanitized = original(values) + + assert isinstance(sanitized, dict) + assert sanitized["second"] == 2 + + @pytest.mark.parametrize( "container_factory", [