mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Fix ui output for duplicated nodes
This commit is contained in:
parent
afa4c7b260
commit
8d17f3c7bf
@ -122,13 +122,6 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
order_mapping[ancestor_id] = len(ancestors) - 1
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
|
||||
def __init__(self, dynprompt, node_ids, is_changed_cache):
|
||||
super().__init__(dynprompt, node_ids, is_changed_cache)
|
||||
|
||||
def include_node_id_in_input(self):
|
||||
return True
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
self.key_class = key_class
|
||||
@ -151,10 +144,8 @@ class BasicCache:
|
||||
node_ids = node_ids.union(subcache.all_node_ids())
|
||||
return node_ids
|
||||
|
||||
def clean_unused(self):
|
||||
assert self.initialized
|
||||
def _clean_cache(self):
|
||||
preserve_keys = set(self.cache_key_set.get_used_keys())
|
||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||
to_remove = []
|
||||
for key in self.cache:
|
||||
if key not in preserve_keys:
|
||||
@ -162,6 +153,9 @@ class BasicCache:
|
||||
for key in to_remove:
|
||||
del self.cache[key]
|
||||
|
||||
def _clean_subcaches(self):
|
||||
preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
|
||||
|
||||
to_remove = []
|
||||
for key in self.subcaches:
|
||||
if key not in preserve_subcaches:
|
||||
@ -169,6 +163,11 @@ class BasicCache:
|
||||
for key in to_remove:
|
||||
del self.subcaches[key]
|
||||
|
||||
def clean_unused(self):
|
||||
assert self.initialized
|
||||
self._clean_cache()
|
||||
self._clean_subcaches()
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
@ -246,15 +245,6 @@ class HierarchicalCache(BasicCache):
|
||||
assert cache is not None
|
||||
return cache._ensure_subcache(node_id, children_ids)
|
||||
|
||||
def all_active_values(self):
|
||||
active_nodes = self.all_node_ids()
|
||||
result = []
|
||||
for node_id in active_nodes:
|
||||
value = self.get(node_id)
|
||||
if value is not None:
|
||||
result.append(value)
|
||||
return result
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
@ -279,6 +269,7 @@ class LRUCache(BasicCache):
|
||||
del self.used_generation[key]
|
||||
if key in self.children:
|
||||
del self.children[key]
|
||||
self._clean_subcaches()
|
||||
|
||||
def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
@ -294,6 +285,9 @@ class LRUCache(BasicCache):
|
||||
return self._set_immediate(node_id, value)
|
||||
|
||||
def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
super()._ensure_subcache(node_id, children_ids)
|
||||
|
||||
self.cache_key_set.add_keys(children_ids)
|
||||
self._mark_used(node_id)
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
@ -303,15 +297,3 @@ class LRUCache(BasicCache):
|
||||
self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
|
||||
return self
|
||||
|
||||
def all_active_values(self):
|
||||
explored = set()
|
||||
to_explore = set(self.cache_key_set.get_used_keys())
|
||||
while len(to_explore) > 0:
|
||||
cache_key = to_explore.pop()
|
||||
if cache_key not in explored:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
explored.add(cache_key)
|
||||
if cache_key in self.children:
|
||||
to_explore.update(self.children[cache_key])
|
||||
return [self.cache[key] for key in explored if key in self.cache]
|
||||
|
||||
|
||||
16
execution.py
16
execution.py
@ -15,7 +15,7 @@ import comfy.model_management
|
||||
import comfy.graph_utils
|
||||
from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy.graph_utils import is_link, GraphBuilder
|
||||
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
|
||||
from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy.cli_args import args
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@ -69,13 +69,13 @@ class CacheSet:
|
||||
# blowing away the cache every time
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
|
||||
self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
|
||||
self.ui = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def recursive_debug_dump(self):
|
||||
@ -486,10 +486,12 @@ class PromptExecutor:
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for ui_info in self.caches.ui.all_active_values():
|
||||
node_id = ui_info["meta"]["node_id"]
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
all_node_ids = self.caches.ui.all_node_ids()
|
||||
for node_id in all_node_ids:
|
||||
ui_info = self.caches.ui.get(node_id)
|
||||
if ui_info is not None:
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
|
||||
@ -117,16 +117,26 @@ class TestExecution:
|
||||
#
|
||||
# Initialize server and client
|
||||
#
|
||||
@fixture(scope="class", autouse=True)
|
||||
def _server(self, args_pytest):
|
||||
@fixture(scope="class", autouse=True, params=[
|
||||
# (use_lru, lru_size)
|
||||
(False, 0),
|
||||
(True, 0),
|
||||
(True, 100),
|
||||
])
|
||||
def _server(self, args_pytest, request):
|
||||
# Start server
|
||||
p = subprocess.Popen([
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
])
|
||||
pargs = [
|
||||
'python','main.py',
|
||||
'--output-directory', args_pytest["output_dir"],
|
||||
'--listen', args_pytest["listen"],
|
||||
'--port', str(args_pytest["port"]),
|
||||
'--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
|
||||
]
|
||||
use_lru, lru_size = request.param
|
||||
if use_lru:
|
||||
pargs += ['--cache-lru', str(lru_size)]
|
||||
print("Running server with args:", pargs)
|
||||
p = subprocess.Popen(pargs)
|
||||
yield
|
||||
p.kill()
|
||||
torch.cuda.empty_cache()
|
||||
@ -159,15 +169,9 @@ class TestExecution:
|
||||
shared_client.set_test_name(f"execution[{request.node.name}]")
|
||||
yield shared_client
|
||||
|
||||
def clear_cache(self, client: ComfyClient):
|
||||
g = GraphBuilder(prefix="foo")
|
||||
random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1)
|
||||
g.node("PreviewImage", images=random.out(0))
|
||||
client.run(g)
|
||||
|
||||
@fixture
|
||||
def builder(self):
|
||||
yield GraphBuilder(prefix="")
|
||||
def builder(self, request):
|
||||
yield GraphBuilder(prefix=request.node.name)
|
||||
|
||||
def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@ -187,7 +191,6 @@ class TestExecution:
|
||||
assert result.did_run(lazy_mix)
|
||||
|
||||
def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
self.clear_cache(client)
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@ -196,14 +199,12 @@ class TestExecution:
|
||||
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
||||
g.node("SaveImage", images=lazy_mix.out(0))
|
||||
|
||||
result1 = client.run(g)
|
||||
client.run(g)
|
||||
result2 = client.run(g)
|
||||
for node_id, node in g.nodes.items():
|
||||
assert result1.did_run(node), f"Node {node_id} didn't run"
|
||||
assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
|
||||
|
||||
def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
|
||||
self.clear_cache(client)
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
|
||||
@ -212,15 +213,11 @@ class TestExecution:
|
||||
lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
|
||||
g.node("SaveImage", images=lazy_mix.out(0))
|
||||
|
||||
result1 = client.run(g)
|
||||
client.run(g)
|
||||
mask.inputs['value'] = 0.4
|
||||
result2 = client.run(g)
|
||||
for node_id, node in g.nodes.items():
|
||||
assert result1.did_run(node), f"Node {node_id} didn't run"
|
||||
assert not result2.did_run(input1), "Input1 should have been cached"
|
||||
assert not result2.did_run(input2), "Input2 should have been cached"
|
||||
assert result2.did_run(mask), "Mask should have been re-run"
|
||||
assert result2.did_run(lazy_mix), "Lazy mix should have been re-run"
|
||||
|
||||
def test_error(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@ -365,7 +362,6 @@ class TestExecution:
|
||||
assert result4.did_run(is_changed), "is_changed should not have been cached"
|
||||
|
||||
def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
|
||||
self.clear_cache(client)
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
|
||||
@ -378,8 +374,6 @@ class TestExecution:
|
||||
result_image = result.get_images(output)[0]
|
||||
expected = 255 // 4
|
||||
assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
|
||||
assert result.did_run(input1)
|
||||
assert result.did_run(input2)
|
||||
|
||||
def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
@ -418,3 +412,17 @@ class TestExecution:
|
||||
assert len(images_literal) == 3, "Should have 2 images"
|
||||
for i in range(3):
|
||||
assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
|
||||
|
||||
def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
|
||||
g = builder
|
||||
input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
|
||||
|
||||
output1 = g.node("PreviewImage", images=input1.out(0))
|
||||
output2 = g.node("PreviewImage", images=input1.out(0))
|
||||
|
||||
result = client.run(g)
|
||||
images1 = result.get_images(output1)
|
||||
images2 = result.get_images(output2)
|
||||
assert len(images1) == 1, "Should have 1 image"
|
||||
assert len(images2) == 1, "Should have 1 image"
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user