Natively support AnyType (*) without __ne__ hacks

This commit is contained in:
Jedrzej Kosinski 2025-12-14 02:41:45 -08:00
parent a4226dbfb0
commit 0d9364e942
2 changed files with 11 additions and 13 deletions

View File

@ -77,16 +77,6 @@ class NumberDisplay(str, Enum):
slider = "slider" slider = "slider"
class _StringIOType(str):
def __ne__(self, value: object) -> bool:
if self == "*" or value == "*":
return False
if not isinstance(value, str):
return True
a = frozenset(self.split(","))
b = frozenset(value.split(","))
return not (b.issubset(a) or a.issubset(b))
class _ComfyType(ABC): class _ComfyType(ABC):
Type = Any Type = Any
io_type: str = None io_type: str = None
@ -126,8 +116,7 @@ def comfytype(io_type: str, **kwargs):
new_cls.__module__ = cls.__module__ new_cls.__module__ = cls.__module__
new_cls.__doc__ = cls.__doc__ new_cls.__doc__ = cls.__doc__
# assign ComfyType attributes, if needed # assign ComfyType attributes, if needed
# NOTE: use __ne__ trick for io_type (see node_typing.IO.__ne__ for details) new_cls.io_type = io_type
new_cls.io_type = _StringIOType(io_type)
if hasattr(new_cls, "Input") and new_cls.Input is not None: if hasattr(new_cls, "Input") and new_cls.Input is not None:
new_cls.Input.Parent = new_cls new_cls.Input.Parent = new_cls
if hasattr(new_cls, "Output") and new_cls.Output is not None: if hasattr(new_cls, "Output") and new_cls.Output is not None:
@ -186,7 +175,7 @@ class Input(_IO_V3):
}) | prune_dict(self.extra_dict) }) | prune_dict(self.extra_dict)
def get_io_type(self): def get_io_type(self):
return _StringIOType(self.io_type) return self.io_type
def get_all(self) -> list[Input]: def get_all(self) -> list[Input]:
return [self] return [self]

View File

@ -21,9 +21,14 @@ def validate_node_input(
""" """
# If the types are exactly the same, we can return immediately # If the types are exactly the same, we can return immediately
# Use pre-union behaviour: inverse of `__ne__` # Use pre-union behaviour: inverse of `__ne__`
# NOTE: this lets legacy '*' Any types work that override the __ne__ method of the str class.
if not received_type != input_type: if not received_type != input_type:
return True return True
# If one of the types is '*', we can return True immediately; this is the 'Any' type.
if received_type == IO.AnyType.io_type or input_type == IO.AnyType.io_type:
return True
# If the received type or input_type is a MatchType, we can return True immediately; # If the received type or input_type is a MatchType, we can return True immediately;
# validation for this is handled by the frontend # validation for this is handled by the frontend
if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type: if received_type == IO.MatchType.io_type or input_type == IO.MatchType.io_type:
@ -42,6 +47,10 @@ def validate_node_input(
received_types = set(t.strip() for t in received_type.split(",")) received_types = set(t.strip() for t in received_type.split(","))
input_types = set(t.strip() for t in input_type.split(",")) input_types = set(t.strip() for t in input_type.split(","))
# If any of the types is '*', we can return True immediately; this is the 'Any' type.
if IO.AnyType.io_type in received_types or IO.AnyType.io_type in input_types:
return True
if strict: if strict:
# In strict mode, all received types must be in the input types # In strict mode, all received types must be in the input types
return received_types.issubset(input_types) return received_types.issubset(input_types)