mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Previously, dependency cycles that were created during node expansion would cause the application to quit (due to an uncaught exception). Now, we'll throw a proper error to the UI. We also make an attempt to 'blame' the most relevant node in the UI.
243 lines
7.1 KiB
Python
243 lines
7.1 KiB
Python
import torch
|
|
from .tools import VariantSupport
|
|
from comfy.graph_utils import GraphBuilder
|
|
|
|
class TestLazyMixImages:
    """Blend two images through a mask, evaluating each image input lazily.

    ``image1`` and ``image2`` are declared ``lazy``. ``check_lazy_status``
    reports which of them are actually required:
    - a mask that is uniformly 0.0 selects only ``image1``;
    - a mask that is uniformly 1.0 selects only ``image2``;
    - any other mask needs both.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE", {"lazy": True}),
                "image2": ("IMAGE", {"lazy": True}),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"

    CATEGORY = "Testing/Nodes"

    def check_lazy_status(self, mask, image1=None, image2=None):
        """Return the names of the lazy inputs that still have to be evaluated.

        Args:
            mask: Blend mask tensor; only its min and max are inspected.
            image1: Already-evaluated first image, or None if not yet provided.
            image2: Already-evaluated second image, or None if not yet provided.

        Returns:
            A list containing "image1" and/or "image2". An input that was
            already supplied (not None) is never requested again.
        """
        mask_min = mask.min()
        mask_max = mask.max()
        needed = []
        # image1 contributes everywhere the mask is not exactly 1.0.
        if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
            needed.append("image1")
        # image2 contributes everywhere the mask is not exactly 0.0.
        if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
            needed.append("image2")
        return needed

    # Not trying to handle different batch sizes here just to keep the demo simple
    def mix(self, mask, image1=None, image2=None):
        """Return ``(image1 * (1 - mask) + image2 * mask,)``.

        Short-circuits to ``(image1,)`` for an all-zero mask and to
        ``(image2,)`` for an all-one mask, so the unused image may be None.

        A 2-D mask gets a leading dimension and a 3-D mask a trailing one,
        so the mask ends up 4-D with its last axis compared against
        ``image1.shape[3]``. This assumes channels-last images; TODO confirm
        against the IMAGE convention.
        """
        mask_min = mask.min()
        mask_max = mask.max()
        if mask_min == 0.0 and mask_max == 0.0:
            return (image1,)
        elif mask_min == 1.0 and mask_max == 1.0:
            return (image2,)

        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(3)
        if mask.shape[3] < image1.shape[3]:
            # Replicate the single mask channel across the image channels.
            mask = mask.repeat(1, 1, 1, image1.shape[3])

        # Previously this line ended in a stray comma (building a 1-tuple)
        # and was followed by a debug print of the entire result tensor.
        result = image1 * (1. - mask) + image2 * mask
        return (result,)
|
|
|
|
class TestVariadicAverage:
    """Average ``input1`` with any extra, consecutively numbered image inputs.

    Only ``input1`` is declared. Further inputs arrive as keyword arguments
    named ``input2``, ``input3``, ... and are consumed until the first gap in
    the numbering.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {"input1": ("IMAGE",)}
        return {"required": required}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "variadic_average"
    CATEGORY = "Testing/Nodes"

    def variadic_average(self, input1, **kwargs):
        """Return a 1-tuple holding the element-wise mean of the gathered images.

        The images are stacked along a new leading axis and averaged over it.
        """
        collected = [input1]
        index = 2
        while True:
            key = f"input{index}"
            if key not in kwargs:
                # Stop at the first missing number; later inputs are ignored.
                break
            collected.append(kwargs[key])
            index += 1
        stacked = torch.stack(collected)
        return (stacked.mean(dim=0),)
|
|
|
|
|
|
class TestCustomIsChanged:
    """Pass an image through unchanged while controlling ``IS_CHANGED``."""

    @classmethod
    def INPUT_TYPES(cls):
        return dict(
            required=dict(image=("IMAGE",)),
            optional=dict(should_change=("BOOL", {"default": False})),
        )

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"
    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, should_change=False):
        """Return ``image`` untouched, wrapped in a 1-tuple."""
        return (image,)

    @classmethod
    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
        """Return NaN when ``should_change`` is truthy, otherwise False.

        NaN never compares equal to itself. Presumably this is used so the
        value can never match a previously recorded one; confirm against
        how the executor compares IS_CHANGED results.
        """
        return float("NaN") if should_change else False
|
|
|
|
class TestCustomValidation1:
    """Multiply two inputs that may each be an IMAGE tensor or a FLOAT.

    Type checking is done by ``VALIDATE_INPUTS`` on the concrete values only.
    """

    @classmethod
    def INPUT_TYPES(cls):
        either = ("IMAGE,FLOAT",)
        return {"required": {"input1": either, "input2": either}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation1"
    CATEGORY = "Testing/Nodes"

    def custom_validation1(self, input1, input2):
        """Return a 1-tuple with the product of the two inputs.

        If both inputs are floats, the product is broadcast onto a ones
        tensor of shape [1, 512, 512, 3] so that the output is a tensor.
        """
        both_scalars = isinstance(input1, float) and isinstance(input2, float)
        if not both_scalars:
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1=None, input2=None):
        """Return True, or an error message for the first badly-typed value.

        A value of None is skipped, so inputs that are not supplied pass.
        """
        accepted = (torch.Tensor, float)
        if input1 is not None and not isinstance(input1, accepted):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, accepted):
            return f"Invalid type of input2: {type(input2)}"
        return True
|
|
|
|
class TestCustomValidation2:
    """Multiply two IMAGE-or-FLOAT inputs, validating values and link types.

    Unlike the value-only variant, ``VALIDATE_INPUTS`` here also receives
    ``input_types`` and rejects any declared type other than IMAGE or FLOAT.
    """

    @classmethod
    def INPUT_TYPES(cls):
        spec = {}
        for name in ("input1", "input2"):
            spec[name] = ("IMAGE,FLOAT",)
        return {"required": spec}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation2"
    CATEGORY = "Testing/Nodes"

    def custom_validation2(self, input1, input2):
        """Return a 1-tuple with ``input1 * input2``.

        Two floats are broadcast onto a [1, 512, 512, 3] ones tensor first.
        """
        if isinstance(input1, float) and isinstance(input2, float):
            return (torch.ones([1, 512, 512, 3]) * input1 * input2,)
        return (input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
        """Return True, or a message describing the first invalid input.

        Concrete values are checked first (None means "not supplied" and is
        skipped). Then, for each name present in ``input_types``, the
        recorded type must be "IMAGE" or "FLOAT". ``input_types`` presumably
        maps input names to the type of whatever is connected; confirm
        against the executor.
        """
        value_types = (torch.Tensor, float)
        if input1 is not None and not isinstance(input1, value_types):
            return f"Invalid type of input1: {type(input1)}"
        if input2 is not None and not isinstance(input2, value_types):
            return f"Invalid type of input2: {type(input2)}"

        allowed = ["IMAGE", "FLOAT"]
        if 'input1' in input_types and input_types['input1'] not in allowed:
            return f"Invalid type of input1: {input_types['input1']}"
        if 'input2' in input_types and input_types['input2'] not in allowed:
            return f"Invalid type of input2: {input_types['input2']}"

        return True
|
|
|
|
@VariantSupport()
class TestCustomValidation3:
    """Multiply two IMAGE-or-FLOAT inputs without any VALIDATE_INPUTS of its own.

    NOTE(review): the ``VariantSupport`` decorator (imported from .tools) is
    presumably what makes the comma-separated "IMAGE,FLOAT" type acceptable
    here. Confirm in tools.py.
    """

    @classmethod
    def INPUT_TYPES(cls):
        image_or_float = ("IMAGE,FLOAT",)
        return {
            "required": {
                "input1": image_or_float,
                "input2": image_or_float,
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation3"
    CATEGORY = "Testing/Nodes"

    def custom_validation3(self, input1, input2):
        """Return a 1-tuple with ``input1 * input2``.

        Two floats are broadcast onto a [1, 512, 512, 3] ones tensor first.
        """
        if not (isinstance(input1, float) and isinstance(input2, float)):
            product = input1 * input2
        else:
            product = torch.ones([1, 512, 512, 3]) * input1 * input2
        return (product,)
|
|
|
|
class TestDynamicDependencyCycle:
    """Expand into a subgraph that deliberately contains a dependency cycle.

    Presumably used to check that the executor reports a cycle introduced
    during node expansion as an error rather than crashing.
    """

    @classmethod
    def INPUT_TYPES(cls):
        image = ("IMAGE",)
        return {"required": {"input1": image, "input2": image}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "dynamic_dependency_cycle"
    CATEGORY = "Testing/Nodes"

    def dynamic_dependency_cycle(self, input1, input2):
        """Return an expansion whose two mix nodes depend on each other."""
        builder = GraphBuilder()
        # A uniform 0.5 mask makes TestLazyMixImages.check_lazy_status ask for
        # both image inputs, so the cycle below cannot be skipped lazily.
        stub_mask = builder.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        upstream = builder.node("TestLazyMixImages", image1=input1, mask=stub_mask.out(0))
        downstream = builder.node(
            "TestLazyMixImages",
            image1=upstream.out(0),
            image2=input2,
            mask=stub_mask.out(0),
        )

        # Create the cycle: upstream now also consumes downstream's output.
        upstream.set_input("image2", downstream.out(0))

        # Keep the original evaluation order: result link first, then finalize.
        result = (downstream.out(0),)
        return {"result": result, "expand": builder.finalize()}
|
|
|
|
# Node classes exported by this module, keyed by their node type name.
TEST_NODE_CLASS_MAPPINGS = dict(
    TestLazyMixImages=TestLazyMixImages,
    TestVariadicAverage=TestVariadicAverage,
    TestCustomIsChanged=TestCustomIsChanged,
    TestCustomValidation1=TestCustomValidation1,
    TestCustomValidation2=TestCustomValidation2,
    TestCustomValidation3=TestCustomValidation3,
    TestDynamicDependencyCycle=TestDynamicDependencyCycle,
)

# Human-readable names for the same node types, in the same order.
TEST_NODE_DISPLAY_NAME_MAPPINGS = dict(
    TestLazyMixImages="Lazy Mix Images",
    TestVariadicAverage="Variadic Average",
    TestCustomIsChanged="Custom IsChanged",
    TestCustomValidation1="Custom Validation 1",
    TestCustomValidation2="Custom Validation 2",
    TestCustomValidation3="Custom Validation 3",
    TestDynamicDependencyCycle="Dynamic Dependency Cycle",
)
|