ComfyUI/tests/inference/testing_nodes/testing-pack/tools.py
Jacob Segal 6d09dd70f8 Make custom VALIDATE_INPUTS skip normal validation
Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`,
that variable will be a dictionary of the socket type of all incoming
connections. If that argument exists, normal socket type validation will
not occur. This removes the last hurdle for enabling variant types
entirely from custom nodes, so I've removed that command-line option.

I've added appropriate unit tests for these changes.
2024-02-24 23:17:01 -08:00

49 lines
2.2 KiB
Python

class SmartType(str):
    """Type-name string whose ``!=`` implements a relaxed compatibility check.

    The value is either the wildcard ``"*"`` or a comma-separated list of
    type names such as ``"INT,FLOAT"``.  Only ``!=`` is overridden; ``==``
    and hashing are inherited from ``str`` unchanged.
    """

    def __ne__(self, other):
        """Return True when ``self`` is NOT compatible with ``other``.

        Rules, in order:

        * If either side is the wildcard ``"*"``, the types are compatible
          (returns False).
        * If ``other`` is not a string, return ``NotImplemented`` so Python
          falls back to its default handling (``!=`` then evaluates to True
          for distinct objects) instead of failing on ``other.split``.
        * Otherwise both sides are split on ``","`` and ``self`` is
          compatible only when every one of its names appears in ``other``.

        The check is directional: ``SmartType("INT") != "INT,FLOAT"`` is
        False, while ``SmartType("INT,FLOAT") != "INT"`` is True.
        """
        # Wildcard check stays first so it keeps matching anything.
        if self == "*" or other == "*":
            return False
        if not isinstance(other, str):
            # Non-string operands have no ``split``; defer to the default
            # comparison machinery rather than raising AttributeError.
            return NotImplemented
        own_names = set(self.split(","))
        other_names = set(other.split(","))
        return not own_names.issubset(other_names)
def _wrap_type_name(type_spec):
    """Return ``type_spec`` as a ``SmartType`` if it is a string, else as-is."""
    # Passing a non-string (for example a list) through ``SmartType`` would
    # replace it with its ``repr`` text, so such specs are left untouched.
    # NOTE(review): presumably list-valued specs are choice lists -- confirm.
    if isinstance(type_spec, str):
        return SmartType(type_spec)
    return type_spec


def _types_mismatch(received, expected):
    """Return True when ``received`` is not acceptable for ``expected``."""
    if isinstance(received, str) and isinstance(expected, str):
        # Relaxed comparison implemented by ``SmartType.__ne__``.
        return SmartType(received) != expected
    # At least one side is not a type-name string: keep the wildcard
    # behaviour, otherwise fall back to an ordinary inequality test.
    if received == "*" or expected == "*":
        return False
    return received != expected


def VariantSupport():
    """Build a class decorator that makes a node's socket types ``SmartType``.

    The returned decorator:

    * replaces ``INPUT_TYPES`` (when present) with a wrapper that converts
      the first element of every tuple spec under ``"required"`` and
      ``"optional"`` to ``SmartType`` when that element is a string;
    * converts the string entries of ``RETURN_TYPES`` (when present) to
      ``SmartType``;
    * installs a ``VALIDATE_INPUTS(input_types)`` function that checks each
      received type against the declared one.

    Raises:
        NotImplementedError: if the class already defines
            ``VALIDATE_INPUTS``.  This is checked before anything is
            modified, so a rejected class is left untouched.
    """
    def decorator(cls):
        if hasattr(cls, "VALIDATE_INPUTS"):
            # Reflection is used to determine what the function signature is, so we can't just change the function signature
            raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")

        if hasattr(cls, "INPUT_TYPES"):
            old_input_types = getattr(cls, "INPUT_TYPES")

            def new_input_types(*args, **kwargs):
                types = old_input_types(*args, **kwargs)
                for category in ("required", "optional"):
                    if category not in types:
                        continue
                    specs = types[category]
                    # Only existing keys are reassigned, so updating the
                    # dict while iterating over it is safe.
                    for key, value in specs.items():
                        # Empty tuples have no type element to wrap.
                        if isinstance(value, tuple) and value:
                            specs[key] = (_wrap_type_name(value[0]),) + value[1:]
                return types

            setattr(cls, "INPUT_TYPES", new_input_types)

        if hasattr(cls, "RETURN_TYPES"):
            old_return_types = cls.RETURN_TYPES
            setattr(cls, "RETURN_TYPES", tuple(_wrap_type_name(x) for x in old_return_types))

        def validate_inputs(input_types):
            """Return True, or an error string for the first mismatched input.

            ``input_types`` maps input names to the type received on that
            input.  Names that are not declared under ``"required"`` or
            ``"optional"`` are ignored, as are values that are already
            ``SmartType`` instances.
            """
            # Looked up at call time so the wrapped INPUT_TYPES is used; a
            # class without INPUT_TYPES simply declares nothing.
            input_types_fn = getattr(cls, "INPUT_TYPES", None)
            inputs = input_types_fn() if input_types_fn is not None else {}
            for key, value in input_types.items():
                if isinstance(value, SmartType):
                    continue
                if "required" in inputs and key in inputs["required"]:
                    spec = inputs["required"][key]
                elif "optional" in inputs and key in inputs["optional"]:
                    spec = inputs["optional"][key]
                else:
                    continue
                # An empty spec declares no type to check against.
                expected_type = spec[0] if spec else None
                if expected_type is not None and _types_mismatch(value, expected_type):
                    return f"Invalid type of {key}: {value} (expected {expected_type})"
            return True

        setattr(cls, "VALIDATE_INPUTS", validate_inputs)
        return cls
    return decorator