diff --git a/comfy_api/latest/_caching.py b/comfy_api/latest/_caching.py index 686c99969..30c8848cd 100644 --- a/comfy_api/latest/_caching.py +++ b/comfy_api/latest/_caching.py @@ -5,7 +5,6 @@ from dataclasses import dataclass @dataclass class CacheContext: - prompt_id: str node_id: str class_type: str cache_key_hash: str # SHA256 hex digest diff --git a/comfy_execution/cache_provider.py b/comfy_execution/cache_provider.py index 20bb51b7c..d455d08e8 100644 --- a/comfy_execution/cache_provider.py +++ b/comfy_execution/cache_provider.py @@ -73,9 +73,12 @@ def _canonicalize(obj: Any) -> Any: elif isinstance(obj, list): return [_canonicalize(item) for item in obj] elif isinstance(obj, dict): - return {str(k): _canonicalize(v) for k, v in sorted(obj.items())} + return {"__dict__": sorted( + [[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()], + key=lambda x: json.dumps(x, sort_keys=True) + )} elif isinstance(obj, (int, float, str, bool, type(None))): - return obj + return (type(obj).__name__, obj) elif isinstance(obj, bytes): return ("__bytes__", obj.hex()) else: diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py index e11990feb..c5782d44a 100644 --- a/comfy_execution/caching.py +++ b/comfy_execution/caching.py @@ -156,8 +156,6 @@ class BasicCache: self.cache = {} self.subcaches = {} - self._current_prompt_id = '' - async def set_prompt(self, dynprompt, node_ids, is_changed_cache): self.dynprompt = dynprompt self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) @@ -319,7 +317,6 @@ class BasicCache: if cache_key_hash is None: return None return CacheContext( - prompt_id=self._current_prompt_id, node_id=node_id, class_type=self._get_class_type(node_id), cache_key_hash=cache_key_hash, @@ -333,7 +330,6 @@ class BasicCache: subcache = self.subcaches.get(subcache_key, None) if subcache is None: subcache = BasicCache(self.key_class) - subcache._current_prompt_id = self._current_prompt_id self.subcaches[subcache_key] = subcache await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) return subcache diff --git a/execution.py b/execution.py index 6d963d640..3d80606b4 100644 --- a/execution.py +++ b/execution.py @@ -714,9 +714,6 @@ class PromptExecutor: self.status_messages = [] self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) - for cache in self.caches.all: - cache._current_prompt_id = prompt_id - self._notify_prompt_lifecycle("start", prompt_id) try: diff --git a/tests-unit/execution_test/test_cache_provider.py b/tests-unit/execution_test/test_cache_provider.py index ccbad88ce..ac3814746 100644 --- a/tests-unit/execution_test/test_cache_provider.py +++ b/tests-unit/execution_test/test_cache_provider.py @@ -69,25 +69,28 @@ class TestCanonicalize: result = _canonicalize(t) assert result[0] == "__tuple__" - assert result[1] == [1, 2, 3] def test_list_preserved(self): """Lists should be recursively canonicalized.""" lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] result = _canonicalize(lst) - # First element should be dict with sorted keys - assert result[0] == {"a": 1, "b": 2} + # First element should be canonicalized dict + assert "__dict__" in result[0] # Second element should be canonicalized frozenset assert result[1][0] == "__frozenset__" - def test_primitives_unchanged(self): - """Primitive types should pass through unchanged.""" - assert _canonicalize(42) == 42 - assert _canonicalize(3.14) == 3.14 - assert _canonicalize("hello") == "hello" - assert _canonicalize(True) is True - assert _canonicalize(None) is None + def test_primitives_include_type(self): + """Primitive types should include type name for disambiguation.""" + assert _canonicalize(42) == ("int", 42) + assert _canonicalize(3.14) == ("float", 3.14) + assert _canonicalize("hello") == ("str", "hello") + assert _canonicalize(True) == ("bool", True) + assert _canonicalize(None) == ("NoneType", None) + + def test_int_and_str_distinguished(self): + """int 7 and str '7' must produce different canonical forms.""" + assert _canonicalize(7) != _canonicalize("7") def test_bytes_converted(self): """Bytes should be converted to hex string.""" @@ -364,13 +367,11 @@ class TestCacheContext: def test_context_creation(self): """CacheContext should be created with all fields.""" context = CacheContext( - prompt_id="prompt-123", node_id="node-456", class_type="KSampler", cache_key_hash="a" * 64, ) - assert context.prompt_id == "prompt-123" assert context.node_id == "node-456" assert context.class_type == "KSampler" assert context.cache_key_hash == "a" * 64