diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index f1b5227db..caa5d4a48 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -252,6 +252,7 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): memo = {} active = set() + snapshots = {} sort_memo = {} processed = 0 stack = [(obj, False)] @@ -263,12 +264,12 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): return value return memo.get(id(value), Unhashable()) - def resolve_unordered_values(current, container_tag): + def resolve_unordered_values(current_items, container_tag): """Resolve a set-like container or fail closed if ordering is ambiguous.""" try: ordered_items = [ (_sanitized_sort_key(item, memo=sort_memo), resolve_value(item)) - for item in current + for item in current_items ] ordered_items.sort(key=lambda item: item[0]) except RuntimeError: @@ -300,18 +301,33 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): active.discard(current_id) try: if current_type is dict: + items = snapshots.pop(current_id, None) + if items is None: + items = list(current.items()) memo[current_id] = ( "dict", - tuple((resolve_value(k), resolve_value(v)) for k, v in current.items()), + tuple((resolve_value(k), resolve_value(v)) for k, v in items), ) elif current_type is list: - memo[current_id] = ("list", tuple(resolve_value(item) for item in current)) + items = snapshots.pop(current_id, None) + if items is None: + items = list(current) + memo[current_id] = ("list", tuple(resolve_value(item) for item in items)) elif current_type is tuple: - memo[current_id] = ("tuple", tuple(resolve_value(item) for item in current)) + items = snapshots.pop(current_id, None) + if items is None: + items = list(current) + memo[current_id] = ("tuple", tuple(resolve_value(item) for item in items)) elif current_type is set: - memo[current_id] = resolve_unordered_values(current, "set") + items = snapshots.pop(current_id, None) + if items is None: + items = list(current) + memo[current_id] = resolve_unordered_values(items, "set") else: - memo[current_id] = resolve_unordered_values(current, "frozenset") + items = snapshots.pop(current_id, None) + if items is None: + items = list(current) + memo[current_id] = resolve_unordered_values(items, "frozenset") except RuntimeError: memo[current_id] = Unhashable() continue @@ -329,6 +345,7 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): if current_type is dict: try: items = list(current.items()) + snapshots[current_id] = items except RuntimeError: memo[current_id] = Unhashable() active.discard(current_id) @@ -339,6 +356,7 @@ def to_hashable(obj, max_nodes=_MAX_SIGNATURE_CONTAINER_VISITS): else: try: items = list(current) + snapshots[current_id] = items except RuntimeError: memo[current_id] = Unhashable() active.discard(current_id)