From b951181123c59d43a3b798fd8859f9d30dd02b7c Mon Sep 17 00:00:00 2001
From: Deluxe233
Date: Wed, 28 Jan 2026 10:37:10 -0500
Subject: [PATCH] Added tests + cleanup

---
 comfy_execution/caching.py         |  7 +---
 execution.py                       | 24 +++--------
 tests/execution/test_execution.py  | 40 ++++++++++++++-----
 .../testing-pack/specific_tests.py | 35 ++++++++++++----
 4 files changed, 65 insertions(+), 41 deletions(-)

diff --git a/comfy_execution/caching.py b/comfy_execution/caching.py
index 8fcd19b00..a579982a0 100644
--- a/comfy_execution/caching.py
+++ b/comfy_execution/caching.py
@@ -26,7 +26,6 @@ class CacheKeySet(ABC):
     def __init__(self, dynprompt, node_ids, is_changed):
         self.keys = {}
         self.subcache_keys = {}
-        self.clean_when = None
 
     @abstractmethod
     async def add_keys(self, node_ids):
@@ -85,7 +84,6 @@ class CacheKeySetID(CacheKeySet):
     def __init__(self, dynprompt, node_ids, is_changed):
         super().__init__(dynprompt, node_ids, is_changed)
         self.dynprompt = dynprompt
-        self.clean_when = "before"
 
     async def add_keys(self, node_ids):
         for node_id in node_ids:
@@ -97,12 +95,11 @@ class CacheKeySetID(CacheKeySet):
             self.keys[node_id] = (node_id, node["class_type"])
             self.subcache_keys[node_id] = (node_id, node["class_type"])
 
-class CacheKeySetUpdatableInputSignature(CacheKeySet):
+class CacheKeySetInputSignature(CacheKeySet):
     def __init__(self, dynprompt, node_ids, is_changed):
         super().__init__(dynprompt, node_ids, is_changed)
         self.dynprompt: DynamicPrompt = dynprompt
         self.is_changed = is_changed
-        self.clean_when = "after"
         self.updated_node_ids = set()
         self.node_sig_cache = {}
 
@@ -241,13 +238,11 @@ class BasicCache:
         self.cache_key_set: CacheKeySet
         self.cache = {}
         self.subcaches = {}
-        self.clean_when = "before"
 
     async def set_prompt(self, dynprompt, node_ids, is_changed):
        self.dynprompt = dynprompt
         self.cache_key_set = self.key_class(dynprompt, node_ids, is_changed)
         await self.cache_key_set.add_keys(node_ids)
-        self.clean_when = self.cache_key_set.clean_when or "before"
         self.is_changed = is_changed
         self.initialized = True
 
diff --git a/execution.py b/execution.py
index 5b21a4228..4337272b0 100644
--- a/execution.py
+++ b/execution.py
@@ -18,7 +18,7 @@ import nodes
 from comfy_execution.caching import (
     BasicCache,
     CacheKeySetID,
-    CacheKeySetUpdatableInputSignature,
+    CacheKeySetInputSignature,
     NullCache,
     HierarchicalCache,
     LRUCache,
@@ -113,15 +113,15 @@ class CacheSet:
 
     # Performs like the old cache -- dump data ASAP
     def init_classic_cache(self):
-        self.outputs = HierarchicalCache(CacheKeySetUpdatableInputSignature)
+        self.outputs = HierarchicalCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_lru_cache(self, cache_size):
-        self.outputs = LRUCache(CacheKeySetUpdatableInputSignature, max_size=cache_size)
+        self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_ram_cache(self, min_headroom):
-        self.outputs = RAMPressureCache(CacheKeySetUpdatableInputSignature)
+        self.outputs = RAMPressureCache(CacheKeySetInputSignature)
         self.objects = HierarchicalCache(CacheKeySetID)
 
     def init_null_cache(self):
@@ -575,9 +575,7 @@ async def execute(server: PromptServer, dynprompt: DynamicPrompt, caches: CacheS
                     cached_outputs.append((True, node_outputs))
             new_node_ids = set(new_node_ids)
             for cache in caches.all:
-                subcache = await cache.ensure_subcache_for(unique_id, new_node_ids)
-                if subcache.clean_when == "before":
-                    subcache.clean_unused()
+                await cache.ensure_subcache_for(unique_id, new_node_ids)
             for node_id in new_output_ids:
                 execution_list.add_node(node_id)
                 execution_list.cache_link(node_id, unique_id)
@@ -706,15 +704,6 @@ class PromptExecutor:
             is_changed = IsChanged(prompt_id, dynamic_prompt, execution_list, extra_data)
             for cache in self.caches.all:
                 await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed)
-                if cache.clean_when == "before":
-                    cache.clean_unused()
-
-            if self.caches.outputs.clean_when == "before":
-                cached_nodes = []
-                for node_id in prompt:
-                    if self.caches.outputs.get(node_id) is not None:
-                        cached_nodes.append(node_id)
-                self.add_message("execution_cached", {"nodes": cached_nodes, "prompt_id": prompt_id}, broadcast=False)
 
             comfy.model_management.cleanup_models_gc()
             pending_subgraph_results = {}
@@ -760,8 +749,7 @@ class PromptExecutor:
 
             comfy.model_management.unload_all_models()
 
         for cache in self.caches.all:
-            if cache.clean_when == "after":
-                cache.clean_unused()
+            cache.clean_unused()
 
 async def validate_inputs(prompt_id, prompt, item, validated):
     unique_id = item
diff --git a/tests/execution/test_execution.py b/tests/execution/test_execution.py
index f73ca7e3c..10cb4216a 100644
--- a/tests/execution/test_execution.py
+++ b/tests/execution/test_execution.py
@@ -552,27 +552,47 @@ class TestExecution:
         assert len(images1) == 1, "Should have 1 image"
         assert len(images2) == 1, "Should have 1 image"
 
-    # This tests that only constant outputs are used in the call to `IS_CHANGED`
-    def test_is_changed_with_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
+    def test_is_changed_passed_cached_outputs(self, client: ComfyClient, builder: GraphBuilder, server):
         g = builder
         input1 = g.node("StubConstantImage", value=0.5, height=512, width=512, batch_size=1)
-        test_node = g.node("TestIsChangedWithConstants", image=input1.out(0), value=0.5)
-
+        test_node = g.node("TestIsChangedWithAllInputs", image=input1.out(0), value=0.5)
         output = g.node("PreviewImage", images=test_node.out(0))
 
-        result = client.run(g)
-        images = result.get_images(output)
+        result1 = client.run(g)
+        images = result1.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
 
-        result = client.run(g)
-        images = result.get_images(output)
+        result2 = client.run(g)
+        images = result2.get_images(output)
         assert len(images) == 1, "Should have 1 image"
         assert numpy.array(images[0]).min() == 63 and numpy.array(images[0]).max() == 63, "Image should have value 0.25"
+
         if server["should_cache_results"]:
-            assert not result.did_run(test_node), "The execution should have been cached"
+            assert not result2.did_run(test_node), "Test node should not have run again"
         else:
-            assert result.did_run(test_node), "The execution should have been re-run"
+            assert result2.did_run(test_node), "Test node should always run here"
+
+    def test_dont_always_run_downstream(self, client: ComfyClient, builder: GraphBuilder, server):
+        g = builder
+        float1 = g.node("TestDontAlwaysRunDownstream", float=0.5) # IS_CHANGED returns float("NaN")
+        image1 = g.node("StubConstantImage", value=float1.out(0), height=512, width=512, batch_size=1)
+        output = g.node("PreviewImage", images=image1.out(0))
+
+        result1 = client.run(g)
+        images = result1.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
+
+        result2 = client.run(g)
+        images = result2.get_images(output)
+        assert len(images) == 1, "Should have 1 image"
+        assert numpy.array(images[0]).min() == 127 and numpy.array(images[0]).max() == 127, "Image should have value 0.50"
+
+        if server["should_cache_results"]:
+            assert not result2.did_run(output), "Output node should not have run the second time"
+        else:
+            assert result2.did_run(output), "Output node should always run here"
 
 
     def test_parallel_sleep_nodes(self, client: ComfyClient, builder: GraphBuilder, skip_timing_checks):
diff --git a/tests/execution/testing_nodes/testing-pack/specific_tests.py b/tests/execution/testing_nodes/testing-pack/specific_tests.py
index 4f8f01ae4..5c74905f5 100644
--- a/tests/execution/testing_nodes/testing-pack/specific_tests.py
+++ b/tests/execution/testing_nodes/testing-pack/specific_tests.py
@@ -100,7 +100,7 @@ class TestCustomIsChanged:
         else:
             return False
 
-class TestIsChangedWithConstants:
+class TestIsChangedWithAllInputs:
     @classmethod
     def INPUT_TYPES(cls):
         return {
@@ -120,10 +120,29 @@ class TestIsChangedWithConstants:
 
     @classmethod
     def IS_CHANGED(cls, image, value):
-        if image is None:
-            return value
-        else:
-            return image.mean().item() * value
+        # if image is None then an exception is thrown and is_changed becomes float("NaN")
+        return image.mean().item() * value
+
+class TestDontAlwaysRunDownstream:
+    @classmethod
+    def INPUT_TYPES(cls):
+        return {
+            "required": {
+                "float": ("FLOAT",),
+            },
+        }
+
+    RETURN_TYPES = ("FLOAT",)
+    FUNCTION = "always_run"
+
+    CATEGORY = "Testing/Nodes"
+
+    def always_run(self, float):
+        return (float,)
+
+    @classmethod
+    def IS_CHANGED(cls, *args, **kwargs):
+        return float("NaN")
 
 class TestCustomValidation1:
     @classmethod
@@ -486,7 +505,8 @@ TEST_NODE_CLASS_MAPPINGS = {
     "TestLazyMixImages": TestLazyMixImages,
     "TestVariadicAverage": TestVariadicAverage,
     "TestCustomIsChanged": TestCustomIsChanged,
-    "TestIsChangedWithConstants": TestIsChangedWithConstants,
+    "TestIsChangedWithAllInputs": TestIsChangedWithAllInputs,
+    "TestDontAlwaysRunDownstream": TestDontAlwaysRunDownstream,
     "TestCustomValidation1": TestCustomValidation1,
     "TestCustomValidation2": TestCustomValidation2,
     "TestCustomValidation3": TestCustomValidation3,
@@ -504,7 +524,8 @@ TEST_NODE_DISPLAY_NAME_MAPPINGS = {
     "TestLazyMixImages": "Lazy Mix Images",
     "TestVariadicAverage": "Variadic Average",
     "TestCustomIsChanged": "Custom IsChanged",
-    "TestIsChangedWithConstants": "IsChanged With Constants",
+    "TestIsChangedWithAllInputs": "IsChanged With All Inputs",
+    "TestDontAlwaysRunDownstream": "Dont Always Run Downstream",
     "TestCustomValidation1": "Custom Validation 1",
     "TestCustomValidation2": "Custom Validation 2",
     "TestCustomValidation3": "Custom Validation 3",