fix: remove prompt_id from CacheContext, type-safe canonicalization
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run

Remove prompt_id from CacheContext — it's not relevant for cache
matching and added unnecessary plumbing (_current_prompt_id on every
cache). Lifecycle hooks still receive prompt_id directly.

Include type name in canonicalized primitives so that int 7 and
str "7" produce distinct hashes. Also canonicalize dict keys properly
instead of str() coercion.

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Deep Mehta 2026-03-09 21:08:22 -07:00
parent 476538ad60
commit 832d3ef4a6
5 changed files with 18 additions and 22 deletions

View File

@ -5,7 +5,6 @@ from dataclasses import dataclass
@dataclass @dataclass
class CacheContext: class CacheContext:
prompt_id: str
node_id: str node_id: str
class_type: str class_type: str
cache_key_hash: str # SHA256 hex digest cache_key_hash: str # SHA256 hex digest

View File

@ -73,9 +73,12 @@ def _canonicalize(obj: Any) -> Any:
elif isinstance(obj, list): elif isinstance(obj, list):
return [_canonicalize(item) for item in obj] return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict): elif isinstance(obj, dict):
return {str(k): _canonicalize(v) for k, v in sorted(obj.items())} return {"__dict__": sorted(
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
key=lambda x: json.dumps(x, sort_keys=True)
)}
elif isinstance(obj, (int, float, str, bool, type(None))): elif isinstance(obj, (int, float, str, bool, type(None))):
return obj return (type(obj).__name__, obj)
elif isinstance(obj, bytes): elif isinstance(obj, bytes):
return ("__bytes__", obj.hex()) return ("__bytes__", obj.hex())
else: else:

View File

@ -156,8 +156,6 @@ class BasicCache:
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
self._current_prompt_id = ''
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt self.dynprompt = dynprompt
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache) self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
@ -319,7 +317,6 @@ class BasicCache:
if cache_key_hash is None: if cache_key_hash is None:
return None return None
return CacheContext( return CacheContext(
prompt_id=self._current_prompt_id,
node_id=node_id, node_id=node_id,
class_type=self._get_class_type(node_id), class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash, cache_key_hash=cache_key_hash,
@ -333,7 +330,6 @@ class BasicCache:
subcache = self.subcaches.get(subcache_key, None) subcache = self.subcaches.get(subcache_key, None)
if subcache is None: if subcache is None:
subcache = BasicCache(self.key_class) subcache = BasicCache(self.key_class)
subcache._current_prompt_id = self._current_prompt_id
self.subcaches[subcache_key] = subcache self.subcaches[subcache_key] = subcache
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache) await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
return subcache return subcache

View File

@ -714,9 +714,6 @@ class PromptExecutor:
self.status_messages = [] self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
for cache in self.caches.all:
cache._current_prompt_id = prompt_id
self._notify_prompt_lifecycle("start", prompt_id) self._notify_prompt_lifecycle("start", prompt_id)
try: try:

View File

@ -69,25 +69,28 @@ class TestCanonicalize:
result = _canonicalize(t) result = _canonicalize(t)
assert result[0] == "__tuple__" assert result[0] == "__tuple__"
assert result[1] == [1, 2, 3]
def test_list_preserved(self): def test_list_preserved(self):
"""Lists should be recursively canonicalized.""" """Lists should be recursively canonicalized."""
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])] lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
result = _canonicalize(lst) result = _canonicalize(lst)
# First element should be dict with sorted keys # First element should be canonicalized dict
assert result[0] == {"a": 1, "b": 2} assert "__dict__" in result[0]
# Second element should be canonicalized frozenset # Second element should be canonicalized frozenset
assert result[1][0] == "__frozenset__" assert result[1][0] == "__frozenset__"
def test_primitives_unchanged(self): def test_primitives_include_type(self):
"""Primitive types should pass through unchanged.""" """Primitive types should include type name for disambiguation."""
assert _canonicalize(42) == 42 assert _canonicalize(42) == ("int", 42)
assert _canonicalize(3.14) == 3.14 assert _canonicalize(3.14) == ("float", 3.14)
assert _canonicalize("hello") == "hello" assert _canonicalize("hello") == ("str", "hello")
assert _canonicalize(True) is True assert _canonicalize(True) == ("bool", True)
assert _canonicalize(None) is None assert _canonicalize(None) == ("NoneType", None)
def test_int_and_str_distinguished(self):
"""int 7 and str '7' must produce different canonical forms."""
assert _canonicalize(7) != _canonicalize("7")
def test_bytes_converted(self): def test_bytes_converted(self):
"""Bytes should be converted to hex string.""" """Bytes should be converted to hex string."""
@ -364,13 +367,11 @@ class TestCacheContext:
def test_context_creation(self): def test_context_creation(self):
"""CacheContext should be created with all fields.""" """CacheContext should be created with all fields."""
context = CacheContext( context = CacheContext(
prompt_id="prompt-123",
node_id="node-456", node_id="node-456",
class_type="KSampler", class_type="KSampler",
cache_key_hash="a" * 64, cache_key_hash="a" * 64,
) )
assert context.prompt_id == "prompt-123"
assert context.node_id == "node-456" assert context.node_id == "node-456"
assert context.class_type == "KSampler" assert context.class_type == "KSampler"
assert context.cache_key_hash == "a" * 64 assert context.cache_key_hash == "a" * 64