mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-03 02:00:29 +08:00
Merge branch 'comfyanonymous:master' into master
This commit is contained in:
commit
b826d3e8c2
@ -175,12 +175,14 @@ def sample_euler_ancestral(model, x, sigmas, extra_args=None, callback=None, dis
|
||||
sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt
|
||||
if sigmas[i + 1] > 0:
|
||||
x = x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
|
||||
if sigma_down == 0:
|
||||
x = denoised
|
||||
else:
|
||||
d = to_d(x, sigmas[i], denoised)
|
||||
# Euler method
|
||||
dt = sigma_down - sigmas[i]
|
||||
x = x + d * dt + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * sigma_up
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
@ -192,19 +194,22 @@ def sample_euler_ancestral_RF(model, x, sigmas, extra_args=None, callback=None,
|
||||
for i in trange(len(sigmas) - 1, disable=disable):
|
||||
denoised = model(x, sigmas[i] * s_in, **extra_args)
|
||||
# sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1], eta=eta)
|
||||
downstep_ratio = 1 + (sigmas[i+1]/sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i+1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i+1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i+1]**2 - sigma_down**2*alpha_ip1**2/alpha_down**2)**0.5
|
||||
if callback is not None:
|
||||
callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
|
||||
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if sigmas[i + 1] > 0 and eta > 0:
|
||||
x = (alpha_ip1/alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
if sigmas[i + 1] == 0:
|
||||
x = denoised
|
||||
else:
|
||||
downstep_ratio = 1 + (sigmas[i + 1] / sigmas[i] - 1) * eta
|
||||
sigma_down = sigmas[i + 1] * downstep_ratio
|
||||
alpha_ip1 = 1 - sigmas[i + 1]
|
||||
alpha_down = 1 - sigma_down
|
||||
renoise_coeff = (sigmas[i + 1]**2 - sigma_down**2 * alpha_ip1**2 / alpha_down**2)**0.5
|
||||
# Euler method
|
||||
sigma_down_i_ratio = sigma_down / sigmas[i]
|
||||
x = sigma_down_i_ratio * x + (1 - sigma_down_i_ratio) * denoised
|
||||
if eta > 0:
|
||||
x = (alpha_ip1 / alpha_down) * x + noise_sampler(sigmas[i], sigmas[i + 1]) * s_noise * renoise_coeff
|
||||
return x
|
||||
|
||||
@torch.no_grad()
|
||||
|
||||
@ -382,6 +382,9 @@ class LoadedModel:
|
||||
if self._patcher_finalizer is not None:
|
||||
self._patcher_finalizer.detach()
|
||||
|
||||
def is_dead(self):
|
||||
return self.real_model() is not None and self.model is None
|
||||
|
||||
|
||||
def use_more_memory(extra_memory, loaded_models, device):
|
||||
for m in loaded_models:
|
||||
@ -422,7 +425,7 @@ def free_memory(memory_required, device, keep_loaded=[]):
|
||||
for i in range(len(current_loaded_models) -1, -1, -1):
|
||||
shift_model = current_loaded_models[i]
|
||||
if shift_model.device == device:
|
||||
if shift_model not in keep_loaded:
|
||||
if shift_model not in keep_loaded and not shift_model.is_dead():
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
@ -544,7 +547,7 @@ def cleanup_models_gc():
|
||||
do_gc = False
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.real_model() is not None and cur.model is None:
|
||||
if cur.is_dead():
|
||||
logging.info("Potential memory leak detected with model {}, doing a full garbage collect, for maximum performance avoid circular references in the model code.".format(cur.real_model().__class__.__name__))
|
||||
do_gc = True
|
||||
break
|
||||
@ -555,7 +558,7 @@ def cleanup_models_gc():
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.real_model() is not None and cur.model is None:
|
||||
if cur.is_dead():
|
||||
logging.warning("WARNING, memory leak with model {}. Please make sure it is not being referenced from somewhere.".format(cur.real_model().__class__.__name__))
|
||||
|
||||
|
||||
|
||||
32
comfy_execution/validation.py
Normal file
32
comfy_execution/validation.py
Normal file
@ -0,0 +1,32 @@
|
||||
from __future__ import annotations
|
||||
|
||||
|
||||
def validate_node_input(
|
||||
received_type: str, input_type: str, strict: bool = False
|
||||
) -> bool:
|
||||
"""
|
||||
received_type and input_type are both strings of the form "T1,T2,...".
|
||||
|
||||
If strict is True, the input_type must contain the received_type.
|
||||
For example, if received_type is "STRING" and input_type is "STRING,INT",
|
||||
this will return True. But if received_type is "STRING,INT" and input_type is
|
||||
"INT", this will return False.
|
||||
|
||||
If strict is False, the input_type must have overlap with the received_type.
|
||||
For example, if received_type is "STRING,BOOLEAN" and input_type is "STRING,INT",
|
||||
this will return True.
|
||||
"""
|
||||
# If the types are exactly the same, we can return immediately
|
||||
if received_type == input_type:
|
||||
return True
|
||||
|
||||
# Split the type strings into sets for comparison
|
||||
received_types = set(t.strip() for t in received_type.split(","))
|
||||
input_types = set(t.strip() for t in input_type.split(","))
|
||||
|
||||
if strict:
|
||||
# In strict mode, all received types must be in the input types
|
||||
return received_types.issubset(input_types)
|
||||
else:
|
||||
# In non-strict mode, there must be at least one type in common
|
||||
return len(received_types.intersection(input_types)) > 0
|
||||
@ -16,6 +16,7 @@ import comfy.model_management
|
||||
from comfy_execution.graph import get_input_info, ExecutionList, DynamicPrompt, ExecutionBlocker
|
||||
from comfy_execution.graph_utils import is_link, GraphBuilder
|
||||
from comfy_execution.caching import HierarchicalCache, LRUCache, CacheKeySetInputSignature, CacheKeySetID
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy.cli_args import args
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@ -527,7 +528,6 @@ class PromptExecutor:
|
||||
comfy.model_management.unload_all_models()
|
||||
|
||||
|
||||
|
||||
def validate_inputs(prompt, item, validated):
|
||||
unique_id = item
|
||||
if unique_id in validated:
|
||||
@ -589,8 +589,8 @@ def validate_inputs(prompt, item, validated):
|
||||
r = nodes.NODE_CLASS_MAPPINGS[o_class_type].RETURN_TYPES
|
||||
received_type = r[val[1]]
|
||||
received_types[x] = received_type
|
||||
if 'input_types' not in validate_function_inputs and received_type != type_input:
|
||||
details = f"{x}, {received_type} != {type_input}"
|
||||
if 'input_types' not in validate_function_inputs and not validate_node_input(received_type, type_input):
|
||||
details = f"{x}, received_type({received_type}) mismatch input_type({type_input})"
|
||||
error = {
|
||||
"type": "return_type_mismatch",
|
||||
"message": "Return type mismatch between linked nodes",
|
||||
|
||||
75
tests-unit/execution_test/validate_node_input_test.py
Normal file
75
tests-unit/execution_test/validate_node_input_test.py
Normal file
@ -0,0 +1,75 @@
|
||||
import pytest
|
||||
from comfy_execution.validation import validate_node_input
|
||||
|
||||
|
||||
def test_exact_match():
|
||||
"""Test cases where types match exactly"""
|
||||
assert validate_node_input("STRING", "STRING")
|
||||
assert validate_node_input("STRING,INT", "STRING,INT")
|
||||
assert (
|
||||
validate_node_input("INT,STRING", "STRING,INT")
|
||||
) # Order shouldn't matter
|
||||
|
||||
|
||||
def test_strict_mode():
|
||||
"""Test strict mode validation"""
|
||||
# Should pass - received type is subset of input type
|
||||
assert validate_node_input("STRING", "STRING,INT", strict=True)
|
||||
assert validate_node_input("INT", "STRING,INT", strict=True)
|
||||
assert validate_node_input("STRING,INT", "STRING,INT,BOOLEAN", strict=True)
|
||||
|
||||
# Should fail - received type is not subset of input type
|
||||
assert not validate_node_input("STRING,INT", "STRING", strict=True)
|
||||
assert not validate_node_input("STRING,BOOLEAN", "STRING", strict=True)
|
||||
assert not validate_node_input("INT,BOOLEAN", "STRING,INT", strict=True)
|
||||
|
||||
|
||||
def test_non_strict_mode():
|
||||
"""Test non-strict mode validation (default behavior)"""
|
||||
# Should pass - types have overlap
|
||||
assert validate_node_input("STRING,BOOLEAN", "STRING,INT")
|
||||
assert validate_node_input("STRING,INT", "INT,BOOLEAN")
|
||||
assert validate_node_input("STRING", "STRING,INT")
|
||||
|
||||
# Should fail - no overlap in types
|
||||
assert not validate_node_input("BOOLEAN", "STRING,INT")
|
||||
assert not validate_node_input("FLOAT", "STRING,INT")
|
||||
assert not validate_node_input("FLOAT,BOOLEAN", "STRING,INT")
|
||||
|
||||
|
||||
def test_whitespace_handling():
|
||||
"""Test that whitespace is handled correctly"""
|
||||
assert validate_node_input("STRING, INT", "STRING,INT")
|
||||
assert validate_node_input("STRING,INT", "STRING, INT")
|
||||
assert validate_node_input(" STRING , INT ", "STRING,INT")
|
||||
assert validate_node_input("STRING,INT", " STRING , INT ")
|
||||
|
||||
|
||||
def test_empty_strings():
|
||||
"""Test behavior with empty strings"""
|
||||
assert validate_node_input("", "")
|
||||
assert not validate_node_input("STRING", "")
|
||||
assert not validate_node_input("", "STRING")
|
||||
|
||||
|
||||
def test_single_vs_multiple():
|
||||
"""Test single type against multiple types"""
|
||||
assert validate_node_input("STRING", "STRING,INT,BOOLEAN")
|
||||
assert validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=False)
|
||||
assert not validate_node_input("STRING,INT,BOOLEAN", "STRING", strict=True)
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"received,input_type,strict,expected",
|
||||
[
|
||||
("STRING", "STRING", False, True),
|
||||
("STRING,INT", "STRING,INT", False, True),
|
||||
("STRING", "STRING,INT", True, True),
|
||||
("STRING,INT", "STRING", True, False),
|
||||
("BOOLEAN", "STRING,INT", False, False),
|
||||
("STRING,BOOLEAN", "STRING,INT", False, True),
|
||||
],
|
||||
)
|
||||
def test_parametrized_cases(received, input_type, strict, expected):
|
||||
"""Parametrized test cases for various scenarios"""
|
||||
assert validate_node_input(received, input_type, strict) == expected
|
||||
@ -24,5 +24,8 @@ def load_extra_path_config(yaml_path):
|
||||
full_path = y
|
||||
if base_path is not None:
|
||||
full_path = os.path.join(base_path, full_path)
|
||||
elif not os.path.isabs(full_path):
|
||||
yaml_dir = os.path.dirname(os.path.abspath(yaml_path))
|
||||
full_path = os.path.abspath(os.path.join(yaml_dir, y))
|
||||
logging.info("Adding extra search path {} {}".format(x, full_path))
|
||||
folder_paths.add_model_folder_path(x, full_path, is_default)
|
||||
|
||||
Loading…
Reference in New Issue
Block a user