mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-14 21:57:33 +08:00
fix: remove prompt_id from CacheContext, type-safe canonicalization
Remove prompt_id from CacheContext — it's not relevant for cache matching and added unnecessary plumbing (_current_prompt_id on every cache). Lifecycle hooks still receive prompt_id directly. Include type name in canonicalized primitives so that int 7 and str "7" produce distinct hashes. Also canonicalize dict keys properly instead of str() coercion. Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
parent
476538ad60
commit
832d3ef4a6
@ -5,7 +5,6 @@ from dataclasses import dataclass
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class CacheContext:
|
class CacheContext:
|
||||||
prompt_id: str
|
|
||||||
node_id: str
|
node_id: str
|
||||||
class_type: str
|
class_type: str
|
||||||
cache_key_hash: str # SHA256 hex digest
|
cache_key_hash: str # SHA256 hex digest
|
||||||
|
|||||||
@ -73,9 +73,12 @@ def _canonicalize(obj: Any) -> Any:
|
|||||||
elif isinstance(obj, list):
|
elif isinstance(obj, list):
|
||||||
return [_canonicalize(item) for item in obj]
|
return [_canonicalize(item) for item in obj]
|
||||||
elif isinstance(obj, dict):
|
elif isinstance(obj, dict):
|
||||||
return {str(k): _canonicalize(v) for k, v in sorted(obj.items())}
|
return {"__dict__": sorted(
|
||||||
|
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
|
||||||
|
key=lambda x: json.dumps(x, sort_keys=True)
|
||||||
|
)}
|
||||||
elif isinstance(obj, (int, float, str, bool, type(None))):
|
elif isinstance(obj, (int, float, str, bool, type(None))):
|
||||||
return obj
|
return (type(obj).__name__, obj)
|
||||||
elif isinstance(obj, bytes):
|
elif isinstance(obj, bytes):
|
||||||
return ("__bytes__", obj.hex())
|
return ("__bytes__", obj.hex())
|
||||||
else:
|
else:
|
||||||
|
|||||||
@ -156,8 +156,6 @@ class BasicCache:
|
|||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.subcaches = {}
|
self.subcaches = {}
|
||||||
|
|
||||||
self._current_prompt_id = ''
|
|
||||||
|
|
||||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed_cache)
|
||||||
@ -319,7 +317,6 @@ class BasicCache:
|
|||||||
if cache_key_hash is None:
|
if cache_key_hash is None:
|
||||||
return None
|
return None
|
||||||
return CacheContext(
|
return CacheContext(
|
||||||
prompt_id=self._current_prompt_id,
|
|
||||||
node_id=node_id,
|
node_id=node_id,
|
||||||
class_type=self._get_class_type(node_id),
|
class_type=self._get_class_type(node_id),
|
||||||
cache_key_hash=cache_key_hash,
|
cache_key_hash=cache_key_hash,
|
||||||
@ -333,7 +330,6 @@ class BasicCache:
|
|||||||
subcache = self.subcaches.get(subcache_key, None)
|
subcache = self.subcaches.get(subcache_key, None)
|
||||||
if subcache is None:
|
if subcache is None:
|
||||||
subcache = BasicCache(self.key_class)
|
subcache = BasicCache(self.key_class)
|
||||||
subcache._current_prompt_id = self._current_prompt_id
|
|
||||||
self.subcaches[subcache_key] = subcache
|
self.subcaches[subcache_key] = subcache
|
||||||
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
await subcache.set_prompt(self.dynprompt, children_ids, self.is_changed_cache)
|
||||||
return subcache
|
return subcache
|
||||||
|
|||||||
@ -714,9 +714,6 @@ class PromptExecutor:
|
|||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
for cache in self.caches.all:
|
|
||||||
cache._current_prompt_id = prompt_id
|
|
||||||
|
|
||||||
self._notify_prompt_lifecycle("start", prompt_id)
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@ -69,25 +69,28 @@ class TestCanonicalize:
|
|||||||
result = _canonicalize(t)
|
result = _canonicalize(t)
|
||||||
|
|
||||||
assert result[0] == "__tuple__"
|
assert result[0] == "__tuple__"
|
||||||
assert result[1] == [1, 2, 3]
|
|
||||||
|
|
||||||
def test_list_preserved(self):
|
def test_list_preserved(self):
|
||||||
"""Lists should be recursively canonicalized."""
|
"""Lists should be recursively canonicalized."""
|
||||||
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
|
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
|
||||||
result = _canonicalize(lst)
|
result = _canonicalize(lst)
|
||||||
|
|
||||||
# First element should be dict with sorted keys
|
# First element should be canonicalized dict
|
||||||
assert result[0] == {"a": 1, "b": 2}
|
assert "__dict__" in result[0]
|
||||||
# Second element should be canonicalized frozenset
|
# Second element should be canonicalized frozenset
|
||||||
assert result[1][0] == "__frozenset__"
|
assert result[1][0] == "__frozenset__"
|
||||||
|
|
||||||
def test_primitives_unchanged(self):
|
def test_primitives_include_type(self):
|
||||||
"""Primitive types should pass through unchanged."""
|
"""Primitive types should include type name for disambiguation."""
|
||||||
assert _canonicalize(42) == 42
|
assert _canonicalize(42) == ("int", 42)
|
||||||
assert _canonicalize(3.14) == 3.14
|
assert _canonicalize(3.14) == ("float", 3.14)
|
||||||
assert _canonicalize("hello") == "hello"
|
assert _canonicalize("hello") == ("str", "hello")
|
||||||
assert _canonicalize(True) is True
|
assert _canonicalize(True) == ("bool", True)
|
||||||
assert _canonicalize(None) is None
|
assert _canonicalize(None) == ("NoneType", None)
|
||||||
|
|
||||||
|
def test_int_and_str_distinguished(self):
|
||||||
|
"""int 7 and str '7' must produce different canonical forms."""
|
||||||
|
assert _canonicalize(7) != _canonicalize("7")
|
||||||
|
|
||||||
def test_bytes_converted(self):
|
def test_bytes_converted(self):
|
||||||
"""Bytes should be converted to hex string."""
|
"""Bytes should be converted to hex string."""
|
||||||
@ -364,13 +367,11 @@ class TestCacheContext:
|
|||||||
def test_context_creation(self):
|
def test_context_creation(self):
|
||||||
"""CacheContext should be created with all fields."""
|
"""CacheContext should be created with all fields."""
|
||||||
context = CacheContext(
|
context = CacheContext(
|
||||||
prompt_id="prompt-123",
|
|
||||||
node_id="node-456",
|
node_id="node-456",
|
||||||
class_type="KSampler",
|
class_type="KSampler",
|
||||||
cache_key_hash="a" * 64,
|
cache_key_hash="a" * 64,
|
||||||
)
|
)
|
||||||
|
|
||||||
assert context.prompt_id == "prompt-123"
|
|
||||||
assert context.node_id == "node-456"
|
assert context.node_id == "node-456"
|
||||||
assert context.class_type == "KSampler"
|
assert context.class_type == "KSampler"
|
||||||
assert context.cache_key_hash == "a" * 64
|
assert context.cache_key_hash == "a" * 64
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user