From 57e8bf6a9f9f550613d00ce36ac60841145a2b7b Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Mon, 2 Dec 2024 19:49:49 -0500 Subject: [PATCH 1/4] Fix case where a memory leak could cause crash. Now the only symptom of code messing up and keeping references to a model object when it should not will be endless prints in the log instead of the next workflow crashing ComfyUI. --- comfy/model_management.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/comfy/model_management.py b/comfy/model_management.py index 2a5009565..a16cf380a 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -369,6 +369,9 @@ class LoadedModel: if self._patcher_finalizer is not None: self._patcher_finalizer.detach() + def is_dead(self): + return self.real_model() is not None and self.model is None + def use_more_memory(extra_memory, loaded_models, device): for m in loaded_models: @@ -409,7 +412,7 @@ def free_memory(memory_required, device, keep_loaded=[]): for i in range(len(current_loaded_models) -1, -1, -1): shift_model = current_loaded_models[i] if shift_model.device == device: - if shift_model not in keep_loaded: + if shift_model not in keep_loaded and not shift_model.is_dead(): can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) shift_model.currently_used = False @@ -531,7 +534,7 @@ def cleanup_models_gc(): do_gc = False for i in range(len(current_loaded_models)): cur = current_loaded_models[i] - if cur.real_model() is not None and cur.model is None: + if cur.is_dead(): logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__)) do_gc = True break @@ -542,7 +545,7 @@ def cleanup_models_gc(): for i in range(len(current_loaded_models)): cur = current_loaded_models[i] - if cur.real_model() is not None and cur.model is None: + if cur.is_dead(): logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__)) From 8d4e06324fb6c477c1f7f409c857c33d0b3b0ce2 Mon Sep 17 00:00:00 2001 From: Chenlei Hu Date: Tue, 3 Dec 2024 02:46:00 -0800 Subject: [PATCH 2/4] Add union link connection type support (#5806) * Add union type support * Move code * nit --- comfy_execution/validation.py | 32 ++++++++ execution.py | 6 +- .../validate_node_input_test.py | 75 +++++++++++++++++++ 3 files changed, 110 insertions(+), 3 deletions(-) create mode 100644 comfy_execution/validation.py create mode 100644 tests-unit/execution_test/validate_node_input_test.py diff --git a/comfy_execution/validation.py b/comfy_execution/validation.py new file mode 100644 index 000000000..43fb6426d --- /dev/null +++ b/comfy_execution/validation.py @@ -0,0 +1,32 @@ +from __future__ import annotations + + +def validate_node_input( + received_type: str, input_type: str, strict: bool = False +) -> bool: + """ + received_type and input_type are both strings of the form "T1,T2,...". + + If strict is True, the input_type must contain the received_type. + For example, if received_type is "STRING" and input_type is "STRING,INT", + this will return True. But if received_type is "STRING,INT" and input_type is + "INT", this will return False. + + If strict is False, the input_type must have overlap with the received_type. + For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT", + this will return True. + """ + # If the types are exactly the same, we can return immediately + if received_type == input_type: + return True + + # Split the type strings into sets for comparison + received_types = set(t.strip() for t in received_type.split(",")) + input_types = set(t.strip() for t in input_type.split(",")) + + if strict: + # In strict mode, all received types must be in the input types + return received_types.issubset(input_types) + else: + # In non-strict mode, there must be at least one type in common + return len(received_types.intersection(input_types)) > 0 diff --git a/execution.py b/execution.py index 768e35abc..929ef85fa 100644 --- a/execution.py +++ b/execution.py @@ -16,6 +16,7 @@ import comfy.model_management from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker from comfy_execution.graph_utils import is_link, GraphBuilder from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID +from comfy_execution.validation import validate_node_input from comfy.cli_args import args class ExecutionResult(Enum): @@ -527,7 +528,6 @@ class PromptExecutor: comfy.model_management.unload_all_models() - def validate_inputs(prompt, item, validated): unique_id = item if unique_id in validated: @@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated): r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES received_type = r[val[1]] received_types[x] = received_type - if 'input_types' not in validate_function_inputs and received_type != type_input: - details = f"{x}, {received_type} != {type_input}" + if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input): + details = f"{x}, received_type({received_type}) mismatch input_type({type_input})" error = { "type": "return_type_mismatch", "message": "Return type mismatch between linked nodes", diff --git a/tests-unit/execution_test/validate_node_input_test.py b/tests-unit/execution_test/validate_node_input_test.py new file mode 100644 index 000000000..d6605e97f --- /dev/null +++ b/tests-unit/execution_test/validate_node_input_test.py @@ -0,0 +1,75 @@ +import pytest +from comfy_execution.validation import validate_node_input + + +def test_exact_match(): + """Test cases where types match exactly""" + assert validate_node_input("STRING", "STRING") + assert validate_node_input("STRING,INT", "STRING,INT") + assert ( + validate_node_input("INT,STRING", "STRING,INT") + ) # Order shouldn't matter + + +def test_strict_mode(): + """Test strict mode validation""" + # Should pass - received type is subset of input type + assert validate_node_input("STRING", "STRING,INT", strict=True) + assert validate_node_input("INT", "STRING,INT", strict=True) + assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True) + + # Should fail - received type is not subset of input type + assert not validate_node_input("STRING,INT", "STRING", strict=True) + assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True) + assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True) + + +def test_non_strict_mode(): + """Test non-strict mode validation (default behavior)""" + # Should pass - types have overlap + assert validate_node_input("STRING,BOOLEAN", "STRING,INT") + assert validate_node_input("STRING,INT", "INT,BOOLEAN") + assert validate_node_input("STRING", "STRING,INT") + + # Should fail - no overlap in types + assert not validate_node_input("BOOLEAN", "STRING,INT") + assert not validate_node_input("FLOAT", "STRING,INT") + assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT") + + +def test_whitespace_handling(): + """Test that whitespace is handled correctly""" + assert validate_node_input("STRING, INT", "STRING,INT") + assert validate_node_input("STRING,INT", "STRING, INT") + assert validate_node_input(" STRING , INT ", "STRING,INT") + assert validate_node_input("STRING,INT", " STRING , INT ") + + +def test_empty_strings(): + """Test behavior with empty strings""" + assert validate_node_input("", "") + assert not validate_node_input("STRING", "") + assert not validate_node_input("", "STRING") + + +def test_single_vs_multiple(): + """Test single type against multiple types""" + assert validate_node_input("STRING", "STRING,INT,BOOLEAN") + assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False) + assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True) + + +@pytest.mark.parametrize( + "received,input_type,strict,expected", + [ + ("STRING", "STRING", False, True), + ("STRING,INT", "STRING,INT", False, True), + ("STRING", "STRING,INT", True, True), + ("STRING,INT", "STRING", True, False), + ("BOOLEAN", "STRING,INT", False, False), + ("STRING,BOOLEAN", "STRING,INT", False, True), + ], +) +def test_parametrized_cases(received, input_type, strict, expected): + """Parametrized test cases for various scenarios""" + assert validate_node_input(received, input_type, strict) == expected From cdc3b97dd5b8f3f7d4915ef9884e0432057cb1a0 Mon Sep 17 00:00:00 2001 From: Alexander Piskun <13381981+bigcat88@users.noreply.github.com> Date: Tue, 3 Dec 2024 14:02:01 +0300 Subject: [PATCH 3/4] resolve relative paths in YAML configuration for extra model paths (#5847) Signed-off-by: bigcat88 --- utils/extra_config.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/utils/extra_config.py b/utils/extra_config.py index 908765902..415db0427 100644 --- a/utils/extra_config.py +++ b/utils/extra_config.py @@ -24,5 +24,8 @@ def load_extra_path_config(yaml_path): full_path = y if base_path is not None: full_path = os.path.join(base_path, full_path) + elif not os.path.isabs(full_path): + yaml_dir = os.path.dirname(os.path.abspath(yaml_path)) + full_path = os.path.abspath(os.path.join(yaml_dir, y)) logging.info("Adding extra search path {} {}".format(x, full_path)) folder_paths.add_model_folder_path(x, full_path, is_default) From c1b92b719d3e09bc859078306256f4d3496e1eca Mon Sep 17 00:00:00 2001 From: comfyanonymous Date: Tue, 3 Dec 2024 06:11:52 -0500 Subject: [PATCH 4/4] Some optimizations to euler a. --- comfy/k_diffusion/sampling.py | 37 ++++++++++++++++++++--------------- 1 file changed, 21 insertions(+), 16 deletions(-) diff --git a/comfy/k_diffusion/sampling.py b/comfy/k_diffusion/sampling.py index 2838b50c7..f08370f83 100644 --- a/comfy/k_diffusion/sampling.py +++ b/comfy/k_diffusion/sampling.py @@ -175,12 +175,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - d = to_d(x, sigmas[i], denoised) - # Euler method - dt = sigma_down - sigmas[i] - x = x + d * dt - if sigmas[i + 1] > 0: - x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up + + if sigma_down == 0: + x = denoised + else: + d = to_d(x, sigmas[i], denoised) + # Euler method + dt = sigma_down - sigmas[i] + x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up return x @torch.no_grad() @@ -192,19 +194,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None, for i in trange(len(sigmas) - 1, disable=disable): denoised = model(x, sigmas[i] * s_in, **extra_args) # sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta) - downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta - sigma_down = sigmas[i+1] * downstep_ratio - alpha_ip1 = 1 - sigmas[i+1] - alpha_down = 1 - sigma_down - renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5 if callback is not None: callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised}) - # Euler method - sigma_down_i_ratio = sigma_down / sigmas[i] - x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised - if sigmas[i + 1] > 0 and eta > 0: - x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff + if sigmas[i + 1] == 0: + x = denoised + else: + downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta + sigma_down = sigmas[i + 1] * downstep_ratio + alpha_ip1 = 1 - sigmas[i + 1] + alpha_down = 1 - sigma_down + renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5 + # Euler method + sigma_down_i_ratio = sigma_down / sigmas[i] + x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised + if eta > 0: + x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff return x @torch.no_grad()