ComfyUI/comfy/nodes/package_typing.py
doctorpangloss 31b6b53236 Quality of life improvements
- export_custom_nodes() finds all the classes that inherit from
   CustomNode and exports them correctly for custom node discovery to
   find
 - regular expressions
 - additional string formatting and parsing nodes
2025-02-12 14:12:10 -08:00

190 lines
5.9 KiB
Python

from __future__ import annotations
from dataclasses import dataclass, field
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List, Type, runtime_checkable
from typing_extensions import TypedDict, NotRequired
# Module-level generic type variable (not referenced by the code shown here).
T = TypeVar("T")
class IntSpecOptions(TypedDict, total=True):
    """Options dict for an ``"INT"`` input spec (second element of ``IntSpec``).

    For type checkers, ``default``, ``min`` and ``max`` are required; the keys
    wrapped in ``NotRequired`` may be omitted.
    """
    default: int
    min: int
    max: int
    step: NotRequired[int]
    # presumably selects how the value is displayed in the UI — confirm against the frontend
    display: NotRequired[Literal["number", "slider"]]
    # marks the input as lazy (see CustomNode.check_lazy_status)
    lazy: NotRequired[bool]
class FloatSpecOptions(TypedDict, total=True):
    """Options dict for a ``"FLOAT"`` input spec (second element of ``FloatSpec``).

    For type checkers, ``default``, ``min`` and ``max`` are required; the keys
    wrapped in ``NotRequired`` may be omitted.
    """
    default: float
    min: float
    max: float
    step: NotRequired[float]
    # presumably the precision values are rounded to — confirm against the frontend
    round: NotRequired[float]
    # presumably selects how the value is displayed in the UI — confirm against the frontend
    display: NotRequired[Literal["number", "slider"]]
    # marks the input as lazy (see CustomNode.check_lazy_status)
    lazy: NotRequired[bool]
class StringSpecOptions(TypedDict, total=True):
    """Options dict for a ``"STRING"`` input spec (second element of ``StringSpec``).

    Every key is optional.
    """
    # presumably requests a multi-line text widget — confirm against the frontend
    multiline: NotRequired[bool]
    default: NotRequired[str]
    # presumably toggles dynamic-prompt syntax handling — confirm against the frontend
    dynamicPrompts: NotRequired[bool]
    # marks the input as lazy (see CustomNode.check_lazy_status)
    lazy: NotRequired[bool]
class BoolSpecOptions(TypedDict):
    """Options dict for a ``"BOOLEAN"`` input spec (second element of ``BooleanSpec``).

    Every key is optional.
    """
    default: NotRequired[bool]
    # marks the input as lazy (see CustomNode.check_lazy_status)
    lazy: NotRequired[bool]
class DefaultSpecOptions(TypedDict):
    """Minimal options dict: an optional default of any type and an optional lazy flag."""
    default: NotRequired[Any]
    # marks the input as lazy (see CustomNode.check_lazy_status)
    lazy: NotRequired[bool]
# TODO: analyze the base_nodes for these types
# Type names a node may declare in RETURN_TYPES / accept as inputs. The Literal
# lists well-known names; ``str`` and ``List`` keep custom types and combos valid.
CommonReturnTypes = Union[
    Literal["IMAGE", "STRING", "INT", "BOOLEAN", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"],
    str,
    List,
]
IntSpec = Tuple[Literal["INT"], IntSpecOptions]
FloatSpec = Tuple[Literal["FLOAT"], FloatSpecOptions]
StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
# The collection of choices offered by a combo input.
_ChoiceList = Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]
# A combo input: the choices alone, e.g. (["a", "b"],), or followed by an
# options dict, e.g. (["a", "b"], {"default": "a"}).
ChoiceSpec = Union[Tuple[_ChoiceList], Tuple[_ChoiceList, DefaultSpecOptions]]
# A typed socket input. The options element is optional: ("MODEL",) is the
# most common way to declare such an input, so a 2-tuple alone is too strict.
NonPrimitiveTypeSpec = Union[Tuple[CommonReturnTypes], Tuple[CommonReturnTypes, Any]]
# Any value allowed in InputTypes["required"] / InputTypes["optional"].
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
# Ready-made INT input specs for seed inputs.
# numpy seeds must be between 0 and 2**32 - 1, which is Seed's upper bound.
Seed = ("INT", dict(default=0, min=0, max=0xFFFF_FFFF))
# Same spec, but spanning the full unsigned 64-bit range.
Seed64 = ("INT", dict(default=0, min=0, max=(1 << 64) - 1))
# Literal-typed spec whose bounds match ``Seed`` (presumably its static type — confirm).
SeedSpec = tuple[
    Literal["INT"],
    TypedDict(
        "SeedSpecOptions",
        {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]},
    ),
]
class HiddenSpec(TypedDict, total=True):
prompt: Literal["PROMPT"]
extra_pnginfo: Literal["EXTRA_PNGINFO"]
class InputTypes(TypedDict, total=True):
    """Shape of the dict returned by ``CustomNode.INPUT_TYPES``.

    ``required`` and ``optional`` map input names to their specs. For type
    checkers, ``required`` must be present, while ``optional`` and ``hidden``
    may be omitted.
    """
    required: Dict[str, InputTypeSpec]
    optional: NotRequired[Dict[str, InputTypeSpec]]
    hidden: NotRequired[HiddenSpec]
# Type of the optional VALIDATE_INPUTS hook: any arguments, returning a bool or
# a string (presumably an error message when validation fails — confirm).
ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
# Type of the IS_CHANGED hook. ``Callable`` cannot express "a class followed by
# arbitrary arguments": ``Callable[[X, ...], R]`` is not a valid type expression,
# because ``...`` may only stand alone as the whole parameter list. So the
# signature accepts any arguments.
IsChangedMethod = Callable[..., str]
class FunctionReturnsUIVariables(TypedDict):
    """Node result carrying a ``ui`` payload and, optionally, ``result`` values."""
    # untyped dict: its structure is not constrained here
    ui: dict
    result: NotRequired[Sequence[Any]]
class SaveNodeResult(TypedDict, total=True):
    """Reference to one file produced by a node.

    ``filename`` and ``subfolder`` locate the file. ``type`` is one of
    ``"output"``, ``"input"`` or ``"temp"`` (presumably the base directory the
    path is relative to — confirm). ``abs_path`` is optional.
    """
    abs_path: NotRequired[str]
    filename: str
    subfolder: str
    type: Literal["output", "input", "temp"]
class UIImagesImagesResult(TypedDict, total=True):
    """``ui`` payload that lists image files via their ``SaveNodeResult`` entries."""
    images: List[SaveNodeResult]
class UIImagesResult(TypedDict, total=True):
    """Node result whose ``ui`` payload lists images; ``result`` values are optional."""
    ui: UIImagesImagesResult
    result: NotRequired[Sequence[Any]]
class UILatentsLatentsResult(TypedDict, total=True):
    """``ui`` payload that lists latent files via their ``SaveNodeResult`` entries."""
    latents: List[SaveNodeResult]
class UILatentsResult(TypedDict, total=True):
    """Node result whose ``ui`` payload lists latents; ``result`` values are optional."""
    ui: UILatentsLatentsResult
    result: NotRequired[Sequence[Any]]
# Union of the accepted node result shapes: a plain tuple, or one of the dicts
# with a "ui" key (presumably what a node's FUNCTION returns — confirm).
ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables]
class CustomNode(Protocol):
    """Structural interface describing a custom node class.

    Members annotated ``ClassVar`` are class-level attributes. Those typed
    ``Optional[...]`` may be ``None``.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes: ...

    # ``ClassVar`` must be the outermost qualifier. ``Optional[ClassVar[...]]``
    # is rejected by type checkers and makes ``typing.get_type_hints`` raise
    # TypeError.
    # Optional method signature for VALIDATE_INPUTS
    VALIDATE_INPUTS: ClassVar[Optional[ValidateInputsMethod]]
    RETURN_TYPES: ClassVar[Tuple[CommonReturnTypes, ...]]
    RETURN_NAMES: ClassVar[Optional[Tuple[str, ...]]]
    OUTPUT_IS_LIST: ClassVar[Optional[Tuple[bool, ...]]]
    INPUT_IS_LIST: ClassVar[Optional[bool]]
    FUNCTION: ClassVar[str]
    CATEGORY: ClassVar[str]
    OUTPUT_NODE: ClassVar[Optional[bool]]
    INFERENCE_MODE: ClassVar[Optional[bool]]

    @classmethod
    def IS_CHANGED(cls, *args, **kwargs) -> str:
        ...

    @classmethod
    def __call__(cls, *args, **kwargs) -> 'CustomNode':
        ...

    def check_lazy_status(self, *args, **kwargs) -> list[str]:
        """
        Return a list of input names that need to be evaluated.
        This function will be called if there are any lazy inputs which have not yet been
        evaluated. As long as you return at least one field which has not yet been evaluated
        (and more exist), this function will be called again once the value of the requested
        field is available.
        Any evaluated inputs will be passed as arguments to this function. Any unevaluated
        inputs will have the value None.
        """
        ...
@dataclass
class ExportedNodes:
    """Node registrations exported by a package: classes, display names and web dirs.

    ``update`` merges another instance in place. ``+`` merges into a fresh
    instance, with the right operand winning on key collisions. ``-`` returns a
    fresh instance without the entries whose keys appear in the matching mapping
    of the other operand. Neither operator mutates its operands.
    """

    NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
    NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
    EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict)

    def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
        """Merge ``exported_nodes`` into ``self`` in place and return ``self``.

        Entries from ``exported_nodes`` overwrite existing entries with the same key.
        """
        self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
        self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
        self.EXTENSION_WEB_DIRS.update(exported_nodes.EXTENSION_WEB_DIRS)
        return self

    def __len__(self) -> int:
        """Return the number of registered node classes."""
        return len(self.NODE_CLASS_MAPPINGS)

    def __sub__(self, other: ExportedNodes) -> ExportedNodes:
        """Return a new instance holding the entries of ``self`` not present in ``other``.

        Each mapping is filtered against the same mapping of ``other``. The
        previous implementation popped display names for keys that ``self``
        might not have, which raised KeyError.
        """

        def _without(mine: Dict[str, Any], theirs: Dict[str, Any]) -> Dict[str, Any]:
            # Keep only keys absent from ``theirs``; ``mine`` itself is not modified.
            return {key: value for key, value in mine.items() if key not in theirs}

        return ExportedNodes(
            NODE_CLASS_MAPPINGS=_without(self.NODE_CLASS_MAPPINGS, other.NODE_CLASS_MAPPINGS),
            NODE_DISPLAY_NAME_MAPPINGS=_without(self.NODE_DISPLAY_NAME_MAPPINGS, other.NODE_DISPLAY_NAME_MAPPINGS),
            EXTENSION_WEB_DIRS=_without(self.EXTENSION_WEB_DIRS, other.EXTENSION_WEB_DIRS),
        )

    def __add__(self, other: ExportedNodes) -> ExportedNodes:
        """Return a new instance with ``self``'s entries merged with ``other``'s (``other`` wins)."""
        return ExportedNodes().update(self).update(other)