ComfyUI/comfy/nodes/package_typing.py
doctorpangloss 96b4e04315 packaging fixes
- enable user db
 - fix main_pre order everywhere
 - fix absolute to relative imports everywhere
 - async better supported
2025-07-15 10:19:33 -07:00

218 lines
7.3 KiB
Python

from __future__ import annotations
from collections import ChainMap
from dataclasses import dataclass, field
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List, Type, MutableMapping
from typing_extensions import TypedDict, NotRequired
T = TypeVar('T')
class IntSpecOptions(TypedDict, total=True):
default: int
min: int
max: int
step: NotRequired[int]
display: NotRequired[Literal["number", "slider"]]
lazy: NotRequired[bool]
control_after_generate: NotRequired[bool]
class FloatSpecOptions(TypedDict, total=True):
default: float
min: float
max: float
step: NotRequired[float]
round: NotRequired[float]
display: NotRequired[Literal["number", "slider"]]
lazy: NotRequired[bool]
class StringSpecOptions(TypedDict, total=True):
multiline: NotRequired[bool]
default: NotRequired[str]
dynamicPrompts: NotRequired[bool]
lazy: NotRequired[bool]
class BoolSpecOptions(TypedDict):
default: NotRequired[bool]
lazy: NotRequired[bool]
class DefaultSpecOptions(TypedDict):
default: NotRequired[Any]
lazy: NotRequired[bool]
# todo: analyze the base_nodes for these types
CommonReturnTypes = Union[
Literal["IMAGE", "STRING", "INT", "BOOLEAN", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"], str, List]
IntSpec = Tuple[Literal["INT"], IntSpecOptions]
FloatSpec = Tuple[Literal["FLOAT"], FloatSpecOptions]
StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
BooleanSpec = Tuple[Literal["BOOLEAN"], BoolSpecOptions]
ChoiceSpec = Tuple[Union[List[str], List[float], List[int], Tuple[str, ...], Tuple[float, ...], Tuple[int, ...]]]
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes, Any]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, BooleanSpec, ChoiceSpec, NonPrimitiveTypeSpec]
# numpy seeds must be between 0 and 2**32 - 1
Seed = ("INT", {"default": 0, "min": 0, "max": 2 ** 32 - 1})
Seed31 = ("INT", {"default": 0, "min": 0, "max": 2 ** 31 - 1})
Seed64 = ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff, "control_after_generate": True})
SeedSpec = tuple[Literal["INT"], TypedDict("SeedSpecOptions", {"default": Literal[0], "min": Literal[0], "max": Literal[4294967295]})]
class HiddenSpec(TypedDict, total=True):
prompt: Literal["PROMPT"]
extra_pnginfo: Literal["EXTRA_PNGINFO"]
class InputTypes(TypedDict, total=True):
required: Dict[str, InputTypeSpec]
optional: NotRequired[Dict[str, InputTypeSpec]]
hidden: NotRequired[HiddenSpec]
ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
IsChangedMethod = Callable[[Type[Any], ...], str]
class FunctionReturnsUIVariables(TypedDict):
ui: dict
result: NotRequired[Sequence[Any]]
class SaveNodeResultT(TypedDict, total=True):
abs_path: NotRequired[str]
filename: str
subfolder: str
type: Literal["output", "input", "temp"]
SaveNodeResult = SaveNodeResultT
class UIImagesImagesResult(TypedDict, total=True):
images: List[SaveNodeResult]
class UIImagesResult(TypedDict, total=True):
ui: UIImagesImagesResult
result: NotRequired[Sequence[Any]]
animated: NotRequired[tuple[bool, ...]]
class UILatentsLatentsResult(TypedDict, total=True):
latents: List[SaveNodeResult]
class UILatentsResult(TypedDict, total=True):
ui: UILatentsLatentsResult
result: NotRequired[Sequence[Any]]
ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables]
class CustomNode(Protocol):
@classmethod
def INPUT_TYPES(cls) -> InputTypes: ...
# Optional method signature for VALIDATE_INPUTS
VALIDATE_INPUTS: Optional[ClassVar[ValidateInputsMethod]]
RETURN_TYPES: ClassVar[Tuple[CommonReturnTypes, ...]]
RETURN_NAMES: Optional[ClassVar[Tuple[str, ...]]]
OUTPUT_IS_LIST: Optional[ClassVar[Tuple[bool, ...]]]
INPUT_IS_LIST: Optional[ClassVar[bool]]
FUNCTION: ClassVar[str]
CATEGORY: ClassVar[str]
OUTPUT_NODE: Optional[ClassVar[bool]]
INFERENCE_MODE: Optional[ClassVar[bool]]
@classmethod
def IS_CHANGED(cls, *args, **kwargs) -> str:
...
@classmethod
def __call__(cls, *args, **kwargs) -> 'CustomNode':
...
def check_lazy_status(self, *args, **kwargs) -> list[str]:
"""
Return a list of input names that need to be evaluated.
This function will be called if there are any lazy inputs which have not yet been
evaluated. As long as you return at least one field which has not yet been evaluated
(and more exist), this function will be called again once the value of the requested
field is available.
Any evaluated inputs will be passed as arguments to this function. Any unevaluated
inputs will have the value None.
"""
...
@dataclass
class ExportedNodes:
NODE_CLASS_MAPPINGS: MutableMapping[str, CustomNode] = field(default_factory=dict)
NODE_DISPLAY_NAME_MAPPINGS: MutableMapping[str, str] = field(default_factory=dict)
EXTENSION_WEB_DIRS: MutableMapping[str, str] = field(default_factory=dict)
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS.update(exported_nodes.NODE_CLASS_MAPPINGS)
self.NODE_DISPLAY_NAME_MAPPINGS.update(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
self.EXTENSION_WEB_DIRS.update(exported_nodes.EXTENSION_WEB_DIRS)
return self
def __len__(self):
return len(self.NODE_CLASS_MAPPINGS)
def __sub__(self, other: ExportedNodes):
exported_nodes = ExportedNodes().update(self)
for self_key in frozenset(exported_nodes.NODE_CLASS_MAPPINGS):
if self_key in other.NODE_CLASS_MAPPINGS:
exported_nodes.NODE_CLASS_MAPPINGS.pop(self_key)
if self_key in other.NODE_DISPLAY_NAME_MAPPINGS:
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.pop(self_key)
for self_key in frozenset(exported_nodes.EXTENSION_WEB_DIRS):
if self_key in other.EXTENSION_WEB_DIRS:
exported_nodes.EXTENSION_WEB_DIRS.pop(self_key)
return exported_nodes
def __add__(self, other):
exported_nodes = ExportedNodes().update(self)
return exported_nodes.update(other)
class _ExportedNodesAsChainMap(ExportedNodes):
@classmethod
def from_iter(cls, *exported_nodes: ExportedNodes):
en = _ExportedNodesAsChainMap()
en.NODE_CLASS_MAPPINGS = ChainMap(*[ncm.NODE_CLASS_MAPPINGS for ncm in exported_nodes])
en.NODE_DISPLAY_NAME_MAPPINGS = ChainMap(*[ncm.NODE_DISPLAY_NAME_MAPPINGS for ncm in exported_nodes])
en.EXTENSION_WEB_DIRS = ChainMap(*[ncm.EXTENSION_WEB_DIRS for ncm in exported_nodes])
return en
def update(self, exported_nodes: ExportedNodes) -> ExportedNodes:
self.NODE_CLASS_MAPPINGS = self.NODE_CLASS_MAPPINGS.new_child(exported_nodes.NODE_CLASS_MAPPINGS)
self.NODE_DISPLAY_NAME_MAPPINGS = self.NODE_DISPLAY_NAME_MAPPINGS.new_child(exported_nodes.NODE_DISPLAY_NAME_MAPPINGS)
self.EXTENSION_WEB_DIRS = self.EXTENSION_WEB_DIRS.new_child(exported_nodes.EXTENSION_WEB_DIRS)
return self
def exported_nodes_view(*exported_nodes: ExportedNodes) -> ExportedNodes:
"""Gets a view of all the provided exported nodes, concatenating them together using a ChainMap internally"""
return _ExportedNodesAsChainMap.from_iter(*exported_nodes)