ComfyUI/tests/inference/testing_nodes/testing-pack/specific_tests.py
Jacob Segal dd3bafb40b Display an error for dependency cycles
Previously, dependency cycles that were created during node expansion
would cause the application to quit (due to an uncaught exception). Now,
we'll throw a proper error to the UI. We also make an attempt to 'blame'
the most relevant node in the UI.
2024-04-20 22:40:38 -07:00

243 lines
7.1 KiB
Python

import torch
from .tools import VariantSupport
from comfy.graph_utils import GraphBuilder
class TestLazyMixImages:
    """Blend two images by a mask, requesting each image input only when needed.

    ``image1`` and ``image2`` are declared lazy.

    - ``check_lazy_status`` asks for ``image1`` unless the mask is entirely
      1.0, and for ``image2`` unless the mask is entirely 0.0.
    - A uniform 0/1 mask therefore never requests the unused input.
    - ``mix`` returns the requested input directly in that case.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "image1": ("IMAGE", {"lazy": True}),
                "image2": ("IMAGE", {"lazy": True}),
                "mask": ("MASK",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "mix"
    CATEGORY = "Testing/Nodes"

    def check_lazy_status(self, mask, image1=None, image2=None):
        """Return the names of the lazy inputs still needed to compute ``mix``.

        Args:
            mask: Blend weights; 0.0 selects ``image1`` and 1.0 selects ``image2``.
            image1: Already-evaluated first image, or None if not yet evaluated.
            image2: Already-evaluated second image, or None if not yet evaluated.

        Returns:
            A list containing any of ``"image1"`` / ``"image2"`` that are still
            None and are actually referenced by the mask.
        """
        mask_min = mask.min()
        mask_max = mask.max()
        needed = []
        # image1 contributes wherever the mask is not exactly 1.0.
        if image1 is None and (mask_min != 1.0 or mask_max != 1.0):
            needed.append("image1")
        # image2 contributes wherever the mask is not exactly 0.0.
        if image2 is None and (mask_min != 0.0 or mask_max != 0.0):
            needed.append("image2")
        return needed

    # Not trying to handle different batch sizes here just to keep the demo simple
    def mix(self, mask, image1=None, image2=None):
        """Return ``image1 * (1 - mask) + image2 * mask`` as a 1-tuple.

        A mask that is entirely 0.0 returns ``image1`` unchanged, and one that
        is entirely 1.0 returns ``image2`` unchanged. The other image may be
        None in those cases because ``check_lazy_status`` never requested it.
        """
        mask_min = mask.min()
        mask_max = mask.max()
        if mask_min == 0.0 and mask_max == 0.0:
            return (image1,)
        if mask_min == 1.0 and mask_max == 1.0:
            return (image2,)

        # Promote the mask to rank 4 and repeat its last dim to match image1's,
        # so it lines up element-wise with the images.
        # NOTE(review): assumes images are (batch, height, width, channels) — confirm.
        if len(mask.shape) == 2:
            mask = mask.unsqueeze(0)
        if len(mask.shape) == 3:
            mask = mask.unsqueeze(3)
        if mask.shape[3] < image1.shape[3]:
            mask = mask.repeat(1, 1, 1, image1.shape[3])

        result = image1 * (1.0 - mask) + image2 * mask
        return (result,)
class TestVariadicAverage:
    """Average ``input1`` with any consecutively numbered extra image inputs.

    Only ``input1`` is declared in ``INPUT_TYPES``. Further inputs arrive as
    keyword arguments named ``input2``, ``input3``, ... and are consumed in
    order until the first missing index.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"input1": ("IMAGE",)}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "variadic_average"
    CATEGORY = "Testing/Nodes"

    def variadic_average(self, input1, **kwargs):
        """Return the element-wise mean of all collected inputs as a 1-tuple."""
        images = [input1]
        next_index = 2
        # Stop at the first gap: e.g. input4 is ignored when input3 is absent.
        while f"input{next_index}" in kwargs:
            images.append(kwargs[f"input{next_index}"])
            next_index += 1
        stacked = torch.stack(images)
        return (stacked.mean(dim=0),)
class TestCustomIsChanged:
    """Pass an image through unchanged, with an ``IS_CHANGED`` hook for cache tests.

    ``IS_CHANGED`` returns NaN when ``should_change`` is truthy and ``False``
    otherwise. NaN never compares equal to itself, so presumably this makes
    the node look changed on every run — confirm against the executor's
    caching logic.
    """

    @classmethod
    def INPUT_TYPES(cls):
        required = {"image": ("IMAGE",)}
        optional = {"should_change": ("BOOL", {"default": False})}
        return {"required": required, "optional": optional}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_is_changed"
    CATEGORY = "Testing/Nodes"

    def custom_is_changed(self, image, should_change=False):
        # should_change only matters to IS_CHANGED; the image passes straight through.
        return (image,)

    @classmethod
    def IS_CHANGED(cls, should_change=False, *args, **kwargs):
        """Return NaN when ``should_change`` is truthy, otherwise ``False``."""
        return float("NaN") if should_change else False
class TestCustomValidation1:
    """Multiply two inputs that may each be an IMAGE tensor or a FLOAT.

    ``VALIDATE_INPUTS`` accepts only the constant values. Any value that is
    not None must be a ``torch.Tensor`` or a ``float``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        image_or_float = ("IMAGE,FLOAT",)
        return {"required": {"input1": image_or_float, "input2": image_or_float}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation1"
    CATEGORY = "Testing/Nodes"

    def custom_validation1(self, input1, input2):
        """Return ``input1 * input2`` as a 1-tuple.

        When both inputs are floats, the product is broadcast into a constant
        tensor of shape ``[1, 512, 512, 3]``.
        """
        if not (isinstance(input1, float) and isinstance(input2, float)):
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input1=None, input2=None):
        """Return True, or an error string for the first badly-typed input."""
        for name, value in (("input1", input1), ("input2", input2)):
            if value is not None and not isinstance(value, (torch.Tensor, float)):
                return f"Invalid type of {name}: {type(value)}"
        return True
class TestCustomValidation2:
    """Multiply two inputs that may each be an IMAGE tensor or a FLOAT.

    ``VALIDATE_INPUTS`` also receives ``input_types``. It rejects any entry
    for ``input1``/``input2`` whose type name is not ``"IMAGE"`` or
    ``"FLOAT"``. (Presumably ``input_types`` maps input names to the type
    of the connected output — confirm against the executor.)
    """

    @classmethod
    def INPUT_TYPES(cls):
        image_or_float = ("IMAGE,FLOAT",)
        return {"required": {"input1": image_or_float, "input2": image_or_float}}

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation2"
    CATEGORY = "Testing/Nodes"

    def custom_validation2(self, input1, input2):
        """Return ``input1 * input2`` as a 1-tuple.

        When both inputs are floats, the product is broadcast into a constant
        tensor of shape ``[1, 512, 512, 3]``.
        """
        if not (isinstance(input1, float) and isinstance(input2, float)):
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)

    @classmethod
    def VALIDATE_INPUTS(cls, input_types, input1=None, input2=None):
        """Return True, or an error string describing the first problem found."""
        named_values = (("input1", input1), ("input2", input2))
        # Value checks for both inputs run before any type-name checks.
        for name, value in named_values:
            if value is not None and not isinstance(value, (torch.Tensor, float)):
                return f"Invalid type of {name}: {type(value)}"
        for name, _ in named_values:
            if name in input_types and input_types[name] not in ("IMAGE", "FLOAT"):
                return f"Invalid type of {name}: {input_types[name]}"
        return True
@VariantSupport()
class TestCustomValidation3:
    """Multiply two inputs that may each be an IMAGE tensor or a FLOAT.

    This class defines no ``VALIDATE_INPUTS`` of its own. The
    ``VariantSupport`` decorator (imported from ``.tools``) is presumably what
    lets the ``"IMAGE,FLOAT"`` union type validate — confirm in ``tools``.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE,FLOAT",),
                "input2": ("IMAGE,FLOAT",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "custom_validation3"
    CATEGORY = "Testing/Nodes"

    def custom_validation3(self, input1, input2):
        """Return ``input1 * input2`` as a 1-tuple.

        When both inputs are floats, the product is broadcast into a constant
        tensor of shape ``[1, 512, 512, 3]``.
        """
        if not (isinstance(input1, float) and isinstance(input2, float)):
            return (input1 * input2,)
        return (torch.ones([1, 512, 512, 3]) * input1 * input2,)
class TestDynamicDependencyCycle:
    """Test node whose expansion deliberately contains a dependency cycle.

    Expanding it creates two ``TestLazyMixImages`` nodes, each feeding the
    other's input. This exercises how cycles created during node expansion
    are handled.
    """

    @classmethod
    def INPUT_TYPES(cls):
        return {
            "required": {
                "input1": ("IMAGE",),
                "input2": ("IMAGE",),
            },
        }

    RETURN_TYPES = ("IMAGE",)
    FUNCTION = "dynamic_dependency_cycle"
    CATEGORY = "Testing/Nodes"

    def dynamic_dependency_cycle(self, input1, input2):
        """Return an expansion whose two mix nodes depend on each other."""
        builder = GraphBuilder()
        half_mask = builder.node("StubMask", value=0.5, height=512, width=512, batch_size=1)
        first_mix = builder.node("TestLazyMixImages", image1=input1, mask=half_mask.out(0))
        second_mix = builder.node(
            "TestLazyMixImages",
            image1=first_mix.out(0),
            image2=input2,
            mask=half_mask.out(0),
        )
        # Close the loop: first_mix now also depends on second_mix, which
        # already depends on first_mix.
        first_mix.set_input("image2", second_mix.out(0))
        result_link = second_mix.out(0)
        return {
            "result": (result_link,),
            "expand": builder.finalize(),
        }
# Node type name -> class for every test node defined in this module.
# The keys are the type names used when building graphs (e.g. the
# "TestLazyMixImages" nodes created by TestDynamicDependencyCycle).
# NOTE(review): presumably merged into the test pack's registered
# node mappings elsewhere — confirm where these are consumed.
TEST_NODE_CLASS_MAPPINGS = {
    "TestLazyMixImages": TestLazyMixImages,
    "TestVariadicAverage": TestVariadicAverage,
    "TestCustomIsChanged": TestCustomIsChanged,
    "TestCustomValidation1": TestCustomValidation1,
    "TestCustomValidation2": TestCustomValidation2,
    "TestCustomValidation3": TestCustomValidation3,
    "TestDynamicDependencyCycle": TestDynamicDependencyCycle,
}
# Human-readable display names, keyed by the same node type names as above.
TEST_NODE_DISPLAY_NAME_MAPPINGS = {
    "TestLazyMixImages": "Lazy Mix Images",
    "TestVariadicAverage": "Variadic Average",
    "TestCustomIsChanged": "Custom IsChanged",
    "TestCustomValidation1": "Custom Validation 1",
    "TestCustomValidation2": "Custom Validation 2",
    "TestCustomValidation3": "Custom Validation 3",
    "TestDynamicDependencyCycle": "Dynamic Dependency Cycle",
}