Mirror of https://github.com/comfyanonymous/ComfyUI.git, synced 2026-01-21 11:50:16 +08:00
Fix ui output for duplicated nodes
commit 8d17f3c7bf
parent afa4c7b260
comfy/caching.py

@@ -122,13 +122,6 @@ class CacheKeySetInputSignature(CacheKeySet):
                 order_mapping[ancestor_id] = len(ancestors) - 1
                 self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
 
-class CacheKeySetInputSignatureWithID(CacheKeySetInputSignature):
-    def __init__(self, dynprompt, node_ids, is_changed_cache):
-        super().__init__(dynprompt, node_ids, is_changed_cache)
-
-    def include_node_id_in_input(self):
-        return True
-
 class BasicCache:
     def __init__(self, key_class):
         self.key_class = key_class
@@ -151,10 +144,8 @@ class BasicCache:
             node_ids = node_ids.union(subcache.all_node_ids())
         return node_ids
 
-    def clean_unused(self):
-        assert self.initialized
+    def _clean_cache(self):
         preserve_keys = set(self.cache_key_set.get_used_keys())
-        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
         to_remove = []
         for key in self.cache:
             if key not in preserve_keys:
@@ -162,6 +153,9 @@ class BasicCache:
         for key in to_remove:
             del self.cache[key]
 
+    def _clean_subcaches(self):
+        preserve_subcaches = set(self.cache_key_set.get_used_subcache_keys())
+
         to_remove = []
         for key in self.subcaches:
             if key not in preserve_subcaches:
@@ -169,6 +163,11 @@ class BasicCache:
         for key in to_remove:
             del self.subcaches[key]
 
+    def clean_unused(self):
+        assert self.initialized
+        self._clean_cache()
+        self._clean_subcaches()
+
     def _set_immediate(self, node_id, value):
         assert self.initialized
         cache_key = self.cache_key_set.get_data_key(node_id)
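Note: these three hunks split the old clean_unused() into _clean_cache() and _clean_subcaches(), with clean_unused() recomposing them. That lets a subclass run subcache expiry on its own; LRUCache does exactly that below. A runnable sketch of the resulting flow follows the last caching.py hunk.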
@@ -246,15 +245,6 @@ class HierarchicalCache(BasicCache):
         assert cache is not None
         return cache._ensure_subcache(node_id, children_ids)
 
-    def all_active_values(self):
-        active_nodes = self.all_node_ids()
-        result = []
-        for node_id in active_nodes:
-            value = self.get(node_id)
-            if value is not None:
-                result.append(value)
-        return result
-
 class LRUCache(BasicCache):
     def __init__(self, key_class, max_size=100):
         super().__init__(key_class)
@@ -279,6 +269,7 @@ class LRUCache(BasicCache):
                 del self.used_generation[key]
                 if key in self.children:
                     del self.children[key]
+        self._clean_subcaches()
 
     def get(self, node_id):
         self._mark_used(node_id)
@@ -294,6 +285,9 @@ class LRUCache(BasicCache):
         return self._set_immediate(node_id, value)
 
     def ensure_subcache_for(self, node_id, children_ids):
+        # Just uses subcaches for tracking 'live' nodes
+        super()._ensure_subcache(node_id, children_ids)
+
         self.cache_key_set.add_keys(children_ids)
         self._mark_used(node_id)
         cache_key = self.cache_key_set.get_data_key(node_id)
@@ -303,15 +297,3 @@ class LRUCache(BasicCache):
             self.children[cache_key].append(self.cache_key_set.get_data_key(child_id))
         return self
 
-    def all_active_values(self):
-        explored = set()
-        to_explore = set(self.cache_key_set.get_used_keys())
-        while len(to_explore) > 0:
-            cache_key = to_explore.pop()
-            if cache_key not in explored:
-                self.used_generation[cache_key] = self.generation
-                explored.add(cache_key)
-                if cache_key in self.children:
-                    to_explore.update(self.children[cache_key])
-        return [self.cache[key] for key in explored if key in self.cache]
-
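Note: a minimal runnable sketch of the cleanup flow after this refactor. KeySet and SketchCache are invented stand-ins for this illustration; only the method names (_clean_cache, _clean_subcaches, clean_unused, get_used_keys, get_used_subcache_keys) and attribute names come from the diff.

class KeySet:
    """Stand-in for the cache-key-set object (invented for this sketch)."""
    def __init__(self, keys, subcache_keys):
        self._keys = set(keys)
        self._sub = set(subcache_keys)
    def get_used_keys(self):
        return self._keys
    def get_used_subcache_keys(self):
        return self._sub

class SketchCache:
    def __init__(self, key_set):
        self.initialized = True
        self.cache_key_set = key_set
        self.cache = {}
        self.subcaches = {}

    def _clean_cache(self):
        # Drop cached values whose keys are no longer in use.
        preserve_keys = self.cache_key_set.get_used_keys()
        for key in [k for k in self.cache if k not in preserve_keys]:
            del self.cache[key]

    def _clean_subcaches(self):
        # Drop subcaches whose keys are no longer in use.
        preserve_subcaches = self.cache_key_set.get_used_subcache_keys()
        for key in [k for k in self.subcaches if k not in preserve_subcaches]:
            del self.subcaches[key]

    def clean_unused(self):
        # BasicCache behaviour: both passes. LRUCache can now run only
        # _clean_subcaches() after its own generation-based eviction.
        assert self.initialized
        self._clean_cache()
        self._clean_subcaches()

c = SketchCache(KeySet(keys={"a"}, subcache_keys=set()))
c.cache = {"a": 1, "b": 2}
c.subcaches = {"s": object()}
c.clean_unused()
assert c.cache == {"a": 1} and c.subcaches == {}

BasicCache.clean_unused() stays behaviorally equivalent to the old version, while LRUCache.clean_unused() swaps the key-preservation pass for its generation-based eviction plus _clean_subcaches().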
execution.py (16 changes)
@@ -15,7 +15,7 @@ import comfy.model_management
 import comfy.graph_utils
 from comfy.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
 from comfy.graph_utils import is_link, GraphBuilder
-from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetInputSignatureWithID, CacheKeySetID
+from comfy.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
 from comfy.cli_args import args
 
 class ExecutionResult(Enum):
@@ -69,13 +69,13 @@ class CacheSet:
     # blowing away the cache every time
     def init_lru_cache(self, cache_size):
         self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
-        self.ui = LRUCache(CacheKeySetInputSignatureWithID, max_size=cache_size)
+        self.ui = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
         self.outputs = HierarchicalCache(CacheKeySetInputSignature)
-        self.ui = HierarchicalCache(CacheKeySetInputSignatureWithID)
+        self.ui = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def recursive_debug_dump(self):
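Note: the ui cache switches from CacheKeySetInputSignatureWithID (deleted from comfy/caching.py above) to the plain CacheKeySetInputSignature, so UI entries for nodes with identical inputs now share a single cache key instead of being kept distinct per node id. The PromptExecutor hunk below compensates by attributing the shared entry back to every node id.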
@@ -486,10 +486,12 @@ class PromptExecutor:
 
         ui_outputs = {}
         meta_outputs = {}
-        for ui_info in self.caches.ui.all_active_values():
-            node_id = ui_info["meta"]["node_id"]
-            ui_outputs[node_id] = ui_info["output"]
-            meta_outputs[node_id] = ui_info["meta"]
+        all_node_ids = self.caches.ui.all_node_ids()
+        for node_id in all_node_ids:
+            ui_info = self.caches.ui.get(node_id)
+            if ui_info is not None:
+                ui_outputs[node_id] = ui_info["output"]
+                meta_outputs[node_id] = ui_info["meta"]
         self.history_result = {
             "outputs": ui_outputs,
             "meta": meta_outputs,
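Note: this is the hunk named by the commit title. Two duplicated nodes hash to the same input-signature key, so iterating cached values (the removed all_active_values()) yields one ui_info whose meta["node_id"] names only one of them; iterating all_node_ids() and querying the cache per node restores an output for both. A toy illustration with plain dicts (the node data and signature() helper are invented, not ComfyUI's cache API):

nodes = {
    "3": {"class_type": "PreviewImage", "inputs": (("images", ("1", 0)),)},
    "4": {"class_type": "PreviewImage", "inputs": (("images", ("1", 0)),)},
}

def signature(node):
    # Key by input signature only, ignoring the node id.
    return (node["class_type"], node["inputs"])

cache = {}
for node_id, node in nodes.items():
    cache[signature(node)] = {"output": "...", "meta": {"node_id": node_id}}

# Old approach: iterate cached values; the duplicated nodes collapse to one
# entry, and meta["node_id"] names whichever node was stored last.
assert len(list(cache.values())) == 1

# New approach: iterate node ids and look each one up, so both duplicated
# nodes get a UI output.
ui_outputs = {nid: cache[signature(n)]["output"] for nid, n in nodes.items()}
assert len(ui_outputs) == 2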
tests/inference/test_execution.py
@@ -117,16 +117,26 @@ class TestExecution:
     #
     # Initialize server and client
     #
-    @fixture(scope="class", autouse=True)
-    def _server(self, args_pytest):
+    @fixture(scope="class", autouse=True, params=[
+        # (use_lru, lru_size)
+        (False, 0),
+        (True, 0),
+        (True, 100),
+    ])
+    def _server(self, args_pytest, request):
         # Start server
-        p = subprocess.Popen([
+        pargs = [
             'python','main.py',
             '--output-directory', args_pytest["output_dir"],
             '--listen', args_pytest["listen"],
             '--port', str(args_pytest["port"]),
             '--extra-model-paths-config', 'tests/inference/extra_model_paths.yaml',
-            ])
+        ]
+        use_lru, lru_size = request.param
+        if use_lru:
+            pargs += ['--cache-lru', str(lru_size)]
+        print("Running server with args:", pargs)
+        p = subprocess.Popen(pargs)
         yield
         p.kill()
         torch.cuda.empty_cache()
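Note: with the parametrized fixture, the whole TestExecution class now runs three times (classic hierarchical cache, LRU cache of size 0, LRU cache of size 100), the latter two starting the server with the --cache-lru flag added above.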
@@ -159,15 +169,9 @@ class TestExecution:
         shared_client.set_test_name(f"execution[{request.node.name}]")
         yield shared_client
 
-    def clear_cache(self, client: ComfyClient):
-        g = GraphBuilder(prefix="foo")
-        random = g.node("StubImage", content="NOISE", height=1, width=1, batch_size=1)
-        g.node("PreviewImage", images=random.out(0))
-        client.run(g)
-
     @fixture
-    def builder(self):
-        yield GraphBuilder(prefix="")
+    def builder(self, request):
+        yield GraphBuilder(prefix=request.node.name)
 
     def test_lazy_input(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
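Note: builder now prefixes every node id with the test name (request.node.name). Since output caching is keyed by input signature rather than node id, identical graphs reused across tests could otherwise hit each other's cache entries; unique prefixes keep tests isolated, which is presumably why the clear_cache() warm-up helper and its call sites are dropped in this and the following hunks.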
@@ -187,7 +191,6 @@ class TestExecution:
         assert result.did_run(lazy_mix)
 
     def test_full_cache(self, client: ComfyClient, builder: GraphBuilder):
-        self.clear_cache(client)
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -196,14 +199,12 @@ class TestExecution:
         lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
         g.node("SaveImage", images=lazy_mix.out(0))
 
-        result1 = client.run(g)
+        client.run(g)
         result2 = client.run(g)
         for node_id, node in g.nodes.items():
-            assert result1.did_run(node), f"Node {node_id} didn't run"
             assert not result2.did_run(node), f"Node {node_id} ran, but should have been cached"
 
     def test_partial_cache(self, client: ComfyClient, builder: GraphBuilder):
-        self.clear_cache(client)
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="NOISE", height=512, width=512, batch_size=1)
@@ -212,15 +213,11 @@ class TestExecution:
         lazy_mix = g.node("TestLazyMixImages", image1=input1.out(0), image2=input2.out(0), mask=mask.out(0))
         g.node("SaveImage", images=lazy_mix.out(0))
 
-        result1 = client.run(g)
+        client.run(g)
         mask.inputs['value'] = 0.4
         result2 = client.run(g)
-        for node_id, node in g.nodes.items():
-            assert result1.did_run(node), f"Node {node_id} didn't run"
         assert not result2.did_run(input1), "Input1 should have been cached"
         assert not result2.did_run(input2), "Input2 should have been cached"
-        assert result2.did_run(mask), "Mask should have been re-run"
-        assert result2.did_run(lazy_mix), "Lazy mix should have been re-run"
 
     def test_error(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
@@ -365,7 +362,6 @@ class TestExecution:
         assert result4.did_run(is_changed), "is_changed should not have been cached"
 
     def test_undeclared_inputs(self, client: ComfyClient, builder: GraphBuilder):
-        self.clear_cache(client)
         g = builder
         input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
         input2 = g.node("StubImage", content="WHITE", height=512, width=512, batch_size=1)
@@ -378,8 +374,6 @@ class TestExecution:
         result_image = result.get_images(output)[0]
         expected = 255 // 4
         assert numpy.array(result_image).min() == expected and numpy.array(result_image).max() == expected, "Image should be grey"
-        assert result.did_run(input1)
-        assert result.did_run(input2)
 
     def test_for_loop(self, client: ComfyClient, builder: GraphBuilder):
         g = builder
@@ -418,3 +412,17 @@ class TestExecution:
         assert len(images_literal) == 3, "Should have 2 images"
         for i in range(3):
             assert numpy.array(images_literal[i]).min() == 255 and numpy.array(images_literal[i]).max() == 255, "All images should be white"
+
+    def test_output_reuse(self, client: ComfyClient, builder: GraphBuilder):
+        g = builder
+        input1 = g.node("StubImage", content="BLACK", height=512, width=512, batch_size=1)
+
+        output1 = g.node("PreviewImage", images=input1.out(0))
+        output2 = g.node("PreviewImage", images=input1.out(0))
+
+        result = client.run(g)
+        images1 = result.get_images(output1)
+        images2 = result.get_images(output2)
+        assert len(images1) == 1, "Should have 1 image"
+        assert len(images2) == 1, "Should have 1 image"
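Note: the new test_output_reuse encodes the bug from the commit title: two PreviewImage nodes fed the same image must each report one image in the result, instead of one of them losing its UI output to the shared cache entry.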