Mirror of https://github.com/comfyanonymous/ComfyUI.git (synced 2026-02-10 05:22:34 +08:00)

Commit b951181123: Added tests + cleanup
Parent: af4d691d1f
comfy_execution/caching.py

@@ -26,7 +26,6 @@ class CacheKeySet(ABC):
     def __init__(self, dynprompt, node_ids, is_changed):
         self.keys = {}
         self.subcache_keys = {}
-        self.clean_when = None

     @abstractmethod
     async def add_keys(self, node_ids):
@@ -85,7 +84,6 @@ class CacheKeySetID(CacheKeySet):
     def __init__(self, dynprompt, node_ids, is_changed):
         super().__init__(dynprompt, node_ids, is_changed)
         self.dynprompt = dynprompt
-        self.clean_when = "before"

     async def add_keys(self, node_ids):
         for node_id in node_ids:
@@ -97,12 +95,11 @@ class CacheKeySetID(CacheKeySet):
             self.keys[node_id] = (node_id, node["class_type"])
             self.subcache_keys[node_id] = (node_id, node["class_type"])

-class CacheKeySetUpdatableInputSignature(CacheKeySet):
+class CacheKeySetInputSignature(CacheKeySet):
     def __init__(self, dynprompt, node_ids, is_changed):
         super().__init__(dynprompt, node_ids, is_changed)
         self.dynprompt: DynamicPrompt = dynprompt
         self.is_changed = is_changed
-        self.clean_when = "after"

         self.updated_node_ids = set()
         self.node_sig_cache = {}
@@ -241,13 +238,11 @@ class BasicCache:
         self.cache_key_set: CacheKeySet
         self.cache = {}
         self.subcaches = {}
-        self.clean_when = "before"

     async def set_prompt(self, dynprompt, node_ids, is_changed):
         self.dynprompt = dynprompt
         self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
         await self.cache_key_set.add_keys(node_ids)
-        self.clean_when = self.cache_key_set.clean_when or "before"
         self.is_changed = is_changed
         self.initialized = True

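Note on the two key-set flavours above: CacheKeySetID keys a node purely by (node_id, class_type), while the renamed CacheKeySetInputSignature (previously CacheKeySetUpdatableInputSignature) also folds the node's is_changed result and input values into the key. The sketch below only illustrates that distinction; the helper names and the exact signature contents are assumptions, not ComfyUI's real implementation.

# Illustrative sketch, not ComfyUI's actual cache key code.
def id_key(node_id, node):
    # Stable across runs as long as the node keeps its id and class type.
    return (node_id, node["class_type"])

def signature_key(node, is_changed_value):
    # Assumed shape: changes whenever the class type, IS_CHANGED result,
    # or literal input values change, so cached outputs are only reused
    # when the inputs really are identical.
    inputs = tuple(sorted(node.get("inputs", {}).items()))
    return (node["class_type"], is_changed_value, inputs)

node = {"class_type": "StubConstantImage", "inputs": {"value": 0.5, "height": 512}}
print(id_key("7", node))          # ('7', 'StubConstantImage')
print(signature_key(node, 0.0))   # ('StubConstantImage', 0.0, (('height', 512), ('value', 0.5)))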
execution.py (24 changed lines)

@@ -18,7 +18,7 @@ import nodes
 from comfy_execution.caching import (
     BasicCache,
     CacheKeySetID,
-    CacheKeySetUpdatableInputSignature,
+    CacheKeySetInputSignature,
     NullCache,
     HierarchicalCache,
     LRUCache,
@@ -113,15 +113,15 @@ class CacheSet:

     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
-        self.outputs = HierarchicalCache(CacheKeySetUpdatableInputSignature)
+        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetUpdatableInputSignature, max_size=cache_size)
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_ram_cache(self, min_headroom):
-        self.outputs = RAMPressureCache(CacheKeySetUpdatableInputSignature)
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)

     def init_null_cache(self):
@@ -575,9 +575,7 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS
             cached_outputs.append((True, node_outputs))
         new_node_ids = set(new_node_ids)
         for cache in caches.all:
-            subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
-            if subcache.clean_when == "before":
-                subcache.clean_unused()
+            await cache.ensure_subcache_for(unique_id, new_node_ids)
         for node_id in new_output_ids:
             execution_list.add_node(node_id)
             execution_list.cache_link(node_id, unique_id)
@@ -706,15 +704,6 @@ class PromptExecutor:
         is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
         for cache in self.caches.all:
             await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)
-            if cache.clean_when == "before":
-                cache.clean_unused()
-
-        if self.caches.outputs.clean_when == "before":
-            cached_nodes = []
-            for node_id in prompt:
-                if self.caches.outputs.get(node_id) is not None:
-                    cached_nodes.append(node_id)
-            self.add_message("execution_cached", {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)

         comfy.model_management.cleanup_models_gc()
         pending_subgraph_results = {}
@@ -760,8 +749,7 @@ class PromptExecutor:
         comfy.model_management.unload_all_models()

         for cache in self.caches.all:
-            if cache.clean_when == "after":
-                cache.clean_unused()
+            cache.clean_unused()

 async def validate_inputs(prompt_id, prompt, item, validated):
     unique_id = item
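With clean_when removed, execution.py no longer decides per cache whether to clean before or after the prompt: set_prompt just installs the new keys, and every cache gets a single unconditional clean_unused() pass once execution finishes (last hunk above). A minimal sketch of that lifecycle with stand-in classes; StubCache and run_prompt are illustrative stubs, not the real PromptExecutor API.

import asyncio

class StubCache:
    # Stand-in for the real cache classes; only the lifecycle is illustrated.
    def __init__(self):
        self.entries = {}
        self.live_keys = set()

    async def set_prompt(self, prompt):
        # Remember which node ids belong to the current prompt.
        self.live_keys = set(prompt)

    def clean_unused(self):
        # Drop entries the current prompt no longer references.
        self.entries = {k: v for k, v in self.entries.items() if k in self.live_keys}

async def run_prompt(caches, prompt):
    for cache in caches:
        await cache.set_prompt(prompt)
    # ... execute nodes here, filling cache.entries ...
    for cache in caches:
        cache.clean_unused()  # one unconditional cleanup pass after execution

asyncio.run(run_prompt([StubCache()], {"1": {"class_type": "PreviewImage"}}))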
tests/inference/test_execution.py

@@ -552,27 +552,47 @@ class TestExecution:
         assert len(images1) == 1, "Should have 1 image"
         assert len(images2) == 1, "Should have 1 image"

-    # This tests that only constant outputs are used in the call to `IS_CHANGED`
-    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
+    def test_is_changed_passed_cached_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
-        test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
+        test_node = g.node("TestIsChangedWithAllInputs", image=input1.out(0), value=0.5)

         output = g.node("PreviewImage", images=test_node.out(0))

-        result = client.run(g)
-        images = result.get_images(output)
+        result1 = client.run(g)
+        images = result1.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"

-        result = client.run(g)
-        images = result.get_images(output)
+        result2 = client.run(g)
+        images = result2.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"

         if server["should_cache_results"]:
-            assert not result.did_run(test_node), "The execution should have been cached"
+            assert not result2.did_run(test_node), "Test node should not have run again"
         else:
-            assert result.did_run(test_node), "The execution should have been re-run"
+            assert result2.did_run(test_node), "Test node should always run here"

+    def test_dont_always_run_downstream(self, client: ComfyClient, builder: GraphBuilder, server):
+        g = builder
+        float1 = g.node("TestDontAlwaysRunDownstream", float=0.5) # IS_CHANGED returns float("NaN")
+        image1 = g.node("StubConstantImage", value=float1.out(0), height=512, width=512, batch_size=1)
+        output = g.node("PreviewImage", images=image1.out(0))
+
+        result1 = client.run(g)
+        images = result1.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
+
+        result2 = client.run(g)
+        images = result2.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
+
+        if server["should_cache_results"]:
+            assert not result2.did_run(output), "Output node should not have run the second time"
+        else:
+            assert result2.did_run(output), "Output node should always run here"
+
     def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
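The new test_dont_always_run_downstream leans on a basic IEEE-754 fact: float("NaN") compares unequal to everything, including another NaN. A node whose IS_CHANGED returns NaN therefore never matches its previously stored signature and re-runs on every prompt, while its downstream consumers still receive the same concrete value (0.5 here) and can be served from cache, which is what the assertions on result2 check. A small self-contained illustration; signature_matches is a hypothetical helper, not ComfyUI code.

import math

def signature_matches(old_is_changed, new_is_changed):
    # Hypothetical helper: a cached entry is reusable only if the stored
    # IS_CHANGED value compares equal to the freshly computed one.
    return old_is_changed == new_is_changed

# The NaN-returning node never matches its previous run...
assert math.isnan(float("NaN"))
assert not signature_matches(float("NaN"), float("NaN"))

# ...but its downstream consumer sees the same 0.5 both times, so it can stay cached.
assert signature_matches(0.5, 0.5)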
tests/inference/testing_nodes/testing-pack/specific_tests.py

@@ -100,7 +100,7 @@ class TestCustomIsChanged:
         else:
             return False

-class TestIsChangedWithConstants:
+class TestIsChangedWithAllInputs:
     @classmethod
     def INPUT_TYPES(cls):
         return {
@@ -120,10 +120,29 @@ class TestIsChangedWithConstants:

     @classmethod
     def IS_CHANGED(cls, image, value):
-        if image is None:
-            return value
-        else:
-            return image.mean().item() * value
+        # if image is None then an exception is thrown and is_changed becomes float("NaN")
+        return image.mean().item() * value
+
+class TestDontAlwaysRunDownstream:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "float": ("FLOAT",),
+            },
+        }
+
+    RETURN_TYPES = ("FLOAT",)
+    FUNCTION = "always_run"
+
+    CATEGORY = "Testing/Nodes"
+
+    def always_run(self, float):
+        return (float,)
+
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs):
+        return float("NaN")

 class TestCustomValidation1:
     @classmethod
@@ -486,7 +505,8 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
     "TestCustomIsChanged": TestCustomIsChanged,
-    "TestIsChangedWithConstants": TestIsChangedWithConstants,
+    "TestIsChangedWithAllInputs": TestIsChangedWithAllInputs,
+    "TestDontAlwaysRunDownstream": TestDontAlwaysRunDownstream,
     "TestCustomValidation1": TestCustomValidation1,
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
@@ -504,7 +524,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestLazyMixImages": "Lazy Mix Images",
     "TestVariadicAverage": "Variadic Average",
     "TestCustomIsChanged": "Custom IsChanged",
-    "TestIsChangedWithConstants": "IsChanged With Constants",
+    "TestIsChangedWithAllInputs": "IsChanged With All Inputs",
+    "TestDontAlwaysRunDownstream": "Dont Always Run Downstream",
     "TestCustomValidation1": "Custom Validation 1",
     "TestCustomValidation2": "Custom Validation 2",
     "TestCustomValidation3": "Custom Validation 3",
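TestIsChangedWithAllInputs.IS_CHANGED drops the old "if image is None" branch because, per the rename, IS_CHANGED now receives all inputs (cached outputs included), so image should always be a real tensor. The new comment spells out the failure mode instead: if image were None, the call would raise and is_changed would fall back to float("NaN"). A hedged sketch of that caller-side fallback, inferred from the comment rather than taken from the executor; compute_is_changed and _FakeNode are illustrative only.

import logging

class _FakeNode:
    # Minimal stand-in with the same IS_CHANGED shape as TestIsChangedWithAllInputs.
    @classmethod
    def IS_CHANGED(cls, image, value):
        return image.mean().item() * value

def compute_is_changed(node_class, **resolved_inputs):
    # Assumed wrapper, inferred from the comment in the diff: on any exception,
    # record NaN so the node is never considered unchanged.
    try:
        return node_class.IS_CHANGED(**resolved_inputs)
    except Exception:
        logging.exception("IS_CHANGED raised; forcing re-execution with NaN")
        return float("NaN")

# With image=None, image.mean() raises AttributeError and the wrapper returns NaN.
print(compute_is_changed(_FakeNode, image=None, value=0.5))  # nan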