mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Additionally, if `VALIDATE_INPUTS` takes an argument named `input_types`, that argument will receive a dictionary mapping each input name to the socket type of its incoming connection. If that argument exists, normal socket type validation will not occur. This removes the last hurdle for enabling variant types entirely from custom nodes, so I've removed that command-line option. I've added appropriate unit tests for these changes.
49 lines
2.2 KiB
Python
49 lines
2.2 KiB
Python
|
|
class SmartType(str):
    """A socket-type string whose ``!=`` understands wildcards and type unions.

    Only ``__ne__`` is overridden, so callers must use ``!=`` to get the
    relaxed compatibility check:

    * ``"*"`` on either side is compatible with anything.
    * Otherwise both sides are split on ``","`` and the left-hand side is
      compatible when its set of type names is a subset of the right-hand
      side's, e.g. ``SmartType("INT") != "INT,FLOAT"`` is ``False`` while
      ``SmartType("INT,FLOAT") != "INT"`` is ``True``.

    ``__eq__`` and ``__hash__`` are intentionally left as plain ``str``
    behaviour, so for these objects ``==`` is NOT the negation of ``!=``.
    """

    def __ne__(self, other):
        """Return ``True`` when *self* is NOT compatible with *other*.

        A ``"*"`` on either side is always compatible (returns ``False``).
        Non-string operands (``None``, combo option lists, ...) cannot be
        parsed as comma-separated type lists; returning ``NotImplemented``
        lets Python fall back to its default ``!=`` handling instead of
        raising ``AttributeError`` from inside a comparison.
        """
        # Wildcard check first so "*" keeps matching non-string operands too.
        if self == "*" or other == "*":
            return False
        if not isinstance(other, str):
            return NotImplemented
        self_types = set(self.split(","))
        other_types = set(other.split(","))
        return not self_types.issubset(other_types)
|
|
|
|
def VariantSupport():
    """Return a class decorator that enables variant ("smart") socket types.

    The decorated class is modified in place and returned:

    * ``INPUT_TYPES`` is wrapped so that every string type name in the
      ``"required"`` and ``"optional"`` sections becomes a ``SmartType``.
      Non-string type specs (e.g. a list of combo options) are left as-is;
      passing them to ``SmartType`` would turn the list into its ``repr``.
    * Each string entry of ``RETURN_TYPES`` becomes a ``SmartType``.
    * A ``VALIDATE_INPUTS(input_types)`` function is installed that checks
      each incoming connection's type against the declared input type.

    Raises:
        NotImplementedError: when the decorator is applied to a class that
            already defines ``VALIDATE_INPUTS``.
    """
    def decorator(cls):
        if hasattr(cls, "INPUT_TYPES"):
            old_input_types = getattr(cls, "INPUT_TYPES")

            def new_input_types(*args, **kwargs):
                # Note: the dict returned by the original INPUT_TYPES is
                # modified in place, not copied.
                types = old_input_types(*args, **kwargs)
                for category in ("required", "optional"):
                    if category not in types:
                        continue
                    section = types[category]
                    for key, value in section.items():
                        # Rewrite only (type_name, ...) tuples whose type is a
                        # string; reassigning an existing key while iterating
                        # does not change the dict's size, so it is safe.
                        if isinstance(value, tuple) and value and isinstance(value[0], str):
                            section[key] = (SmartType(value[0]),) + value[1:]
                return types

            # staticmethod keeps ``cls.INPUT_TYPES()`` unchanged and also makes
            # ``instance.INPUT_TYPES()`` work (a plain function would be bound
            # and receive the instance as an extra positional argument).
            setattr(cls, "INPUT_TYPES", staticmethod(new_input_types))

        if hasattr(cls, "RETURN_TYPES"):
            setattr(
                cls,
                "RETURN_TYPES",
                tuple(SmartType(t) if isinstance(t, str) else t for t in cls.RETURN_TYPES),
            )

        if hasattr(cls, "VALIDATE_INPUTS"):
            # Reflection is used to determine what the function signature is, so we can't just change the function signature
            raise NotImplementedError("VariantSupport does not support VALIDATE_INPUTS yet")

        def validate_inputs(input_types):
            """Return ``True`` if all connected types fit, else an error string.

            *input_types* maps input names to the type of the connection
            feeding that input.
            """
            inputs = cls.INPUT_TYPES()
            required = inputs.get("required") or {}
            optional = inputs.get("optional") or {}
            for key, received in input_types.items():
                # Received SmartTypes are accepted without checking.
                # NOTE(review): presumably these come from another
                # VariantSupport node's RETURN_TYPES — confirm intended.
                if isinstance(received, SmartType):
                    continue
                if key in required:
                    expected = required[key][0]
                elif key in optional:
                    expected = optional[key][0]
                else:
                    continue
                if expected is None:
                    continue
                if isinstance(received, str) and isinstance(expected, str):
                    # Wildcard / comma-union aware comparison.
                    mismatch = SmartType(received) != expected
                else:
                    # Combo option lists and other non-string specs: plain comparison.
                    mismatch = received != expected
                if mismatch:
                    return f"Invalid type of {key}: {received} (expected {expected})"
            return True

        setattr(cls, "VALIDATE_INPUTS", staticmethod(validate_inputs))
        return cls

    return decorator
|
|
|