# ComfyUI/comfy/nodes/package_typing.py
# (file-viewer metadata: 143 lines, 4.2 KiB, Python)

from __future__ import annotations
from dataclasses import dataclass, field
from typing_extensions import TypedDict, NotRequired, Generic
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List
# Module-level type variable; used below as the return annotation of
# ``CustomNode.__call__``.
T = TypeVar("T")
class IntSpecOptions(TypedDict, total=True):
default: int
min: int
max: int
step: NotRequired[int]
display: NotRequired[Literal["number", "slider"]]
class FloatSpecOptions(TypedDict, total=True):
default: float
min: float
max: float
step: NotRequired[float]
round: NotRequired[float]
display: NotRequired[Literal["number", "slider"]]
class StringSpecOptions(TypedDict, total=True):
multiline: NotRequired[bool]
default: NotRequired[str]
# TODO: analyze the base_nodes for these types
# Well-known type names, plus plain ``str`` so that any other (custom) type
# name is also accepted.
CommonReturnTypes = Union[
    Literal[
        "IMAGE",
        "STRING",
        "INT",
        "FLOAT",
        "CONDITIONING",
        "LATENT",
        "MASK",
        "MODEL",
        "VAE",
        "CLIP",
    ],
    str,
]

# Shapes of a single input declaration, i.e. the values of
# ``InputTypes["required"]`` / ``InputTypes["optional"]``.
# Primitive kinds are a (kind name, options dict) pair:
IntSpec = Tuple[Literal["INT"], IntSpecOptions]
FloatSpec = Tuple[Literal["FLOAT"], FloatSpecOptions]
StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
# A one-element tuple holding the sequence of allowed values:
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
# A one-element tuple holding just a type name, e.g. ``("IMAGE",)``:
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes]

InputTypeSpec = Union[
    IntSpec,
    FloatSpec,
    StringSpec,
    ChoiceSpec,
    NonPrimitiveTypeSpec,
]
class HiddenSpec(TypedDict):
    """Value of the ``hidden`` key of ``InputTypes``.

    Each key maps to a fixed sentinel string. Both keys are required because
    the TypedDict is total.
    """

    prompt: Literal["PROMPT"]
    extra_pnginfo: Literal["EXTRA_PNGINFO"]
class InputTypes(TypedDict, total=True):
required: Dict[str, InputTypeSpec]
optional: NotRequired[Dict[str, InputTypeSpec]]
hidden: NotRequired[HiddenSpec]
# Type of ``CustomNode.VALIDATE_INPUTS``: either absent (``None``) or a callable
# taking arbitrary arguments and returning ``bool`` or ``str``.
# NOTE(review): presumably ``True`` means valid and a ``str`` is an error
# message; confirm against the caller.
ValidateInputsMethod = Union[Callable[..., Union[bool, str]], None]
class FunctionReturnsUIVariables(TypedDict):
    """Generic node result carrying an untyped ``ui`` dict.

    This is one member of the ``ValidatedNodeResult`` union.
    """

    ui: dict
class SaveNodeResult(TypedDict, total=True):
abs_path: NotRequired[str]
filename: str
subfolder: str
type: Literal["output", "input", "temp"]
class UIImagesImagesResult(TypedDict):
    """Inner ``ui`` payload of ``UIImagesResult``: a list of saved images."""

    images: List[SaveNodeResult]


class UIImagesResult(TypedDict):
    """Node result of the form ``{"ui": {"images": [...]}}``."""

    ui: UIImagesImagesResult
class UILatentsLatentsResult(TypedDict):
    """Inner ``ui`` payload of ``UILatentsResult``: a list of saved latents."""

    latents: List[SaveNodeResult]


class UILatentsResult(TypedDict):
    """Node result of the form ``{"ui": {"latents": [...]}}``."""

    ui: UILatentsLatentsResult
# Any shape a node function's result may take: a plain tuple of outputs, or
# one of the ``ui``-carrying dicts defined above.
ValidatedNodeResult = Union[
    Tuple,
    UIImagesResult,
    UILatentsResult,
    FunctionReturnsUIVariables,
]
class CustomNode(Protocol):
    """Structural (Protocol) description of a custom node class.

    Lists the class-level declarations a node provides. Values in
    ``ExportedNodes.NODE_CLASS_MAPPINGS`` are annotated with this type.
    """

    @classmethod
    def INPUT_TYPES(cls) -> InputTypes:
        """Return the node's input declaration."""
        ...

    # Optional method signature for VALIDATE_INPUTS; defaults to None.
    VALIDATE_INPUTS: ClassVar[ValidateInputsMethod] = None
    RETURN_TYPES: ClassVar[Tuple[CommonReturnTypes, ...]]
    # ``ClassVar`` must be the outermost qualifier. The previous spelling
    # ``Optional[ClassVar[...]]`` is rejected by type checkers and makes
    # ``typing.get_type_hints(CustomNode)`` raise TypeError.
    RETURN_NAMES: ClassVar[Optional[Tuple[str, ...]]]
    OUTPUT_IS_LIST: ClassVar[Optional[Tuple[bool, ...]]]
    INPUT_IS_LIST: ClassVar[Optional[bool]]
    FUNCTION: ClassVar[str]
    CATEGORY: ClassVar[str]
    OUTPUT_NODE: ClassVar[Optional[bool]]

    # NOTE(review): ``T`` is a module-level TypeVar that this Protocol does not
    # bind, so the return type is effectively unconstrained.
    def __call__(self) -> T:
        ...
@dataclass
class ExportedNodes:
    """Container for the mappings a set of custom nodes exports.

    ``update`` merges another instance in place. ``+`` and ``-`` build new
    instances and leave both operands untouched. ``len()`` is the number of
    entries in ``NODE_CLASS_MAPPINGS``.
    """

    NODE_CLASS_MAPPINGS: Dict[str, CustomNode] = field(default_factory=dict)
    NODE_DISPLAY_NAME_MAPPINGS: Dict[str, str] = field(default_factory=dict)
    EXTENSION_WEB_DIRS: Dict[str, str] = field(default_factory=dict)

    def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
        """Merge *exported_nodes* into this instance in place.

        On key collisions the entry from *exported_nodes* wins.

        Returns:
            ``self``, so calls can be chained.
        """
        self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
        self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
        self.EXTENSION_WEB_DIRS.update(exported_nodes.EXTENSION_WEB_DIRS)
        return self

    def __len__(self) -> int:
        """Return the number of exported node classes."""
        return len(self.NODE_CLASS_MAPPINGS)

    def __sub__(self, other: ExportedNodes) -> ExportedNodes:
        """Return a copy of ``self`` minus the entries keyed in *other*.

        - A node class is dropped if its key is in ``other.NODE_CLASS_MAPPINGS``.
        - A display name is dropped if its key is a node class of ``self`` and
          appears in ``other.NODE_DISPLAY_NAME_MAPPINGS``.
        - A web dir is dropped if its key is in ``other.EXTENSION_WEB_DIRS``.
        """
        exported_nodes = ExportedNodes().update(self)
        # Iterate over key snapshots: popping from a dict while iterating it
        # raises "RuntimeError: dictionary changed size during iteration".
        for self_key in tuple(exported_nodes.NODE_CLASS_MAPPINGS):
            if self_key in other.NODE_CLASS_MAPPINGS:
                exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
            if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
                # ``self`` may export a class without a display name; the
                # default keeps this from raising KeyError.
                exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key, None)
        for self_key in tuple(exported_nodes.EXTENSION_WEB_DIRS):
            if self_key in other.EXTENSION_WEB_DIRS:
                exported_nodes.EXTENSION_WEB_DIRS.pop(self_key)
        return exported_nodes

    def __add__(self, other: ExportedNodes) -> ExportedNodes:
        """Return a new instance holding ``self`` merged with *other*.

        On key collisions *other* wins.
        """
        exported_nodes = ExportedNodes().update(self)
        return exported_nodes.update(other)