Fix CLIPLoader node, fix CustomNode typing, improve digest

This commit is contained in:
Benjamin Berman 2024-02-29 15:54:42 -08:00
parent bae2068111
commit e6623a1359
3 changed files with 93 additions and 38 deletions

View File

@ -2,10 +2,19 @@ from __future__ import annotations
import hashlib
import json
from typing_extensions import Buffer
def digest(data: dict | str) -> str:
json_str = data if isinstance(data, str) else json.dumps(data, separators=(',', ':'))
def digest(data: dict | str | Buffer) -> str:
hash_object = hashlib.sha256()
if isinstance(data, Buffer):
hash_object.update(data)
else:
if isinstance(data, str):
json_str = data
elif isinstance(data, dict):
json_str = json.dumps(data, separators=(',', ':'))
else:
raise RuntimeError("invalid data type")
hash_object.update(json_str.encode())
return hash_object.hexdigest()

View File

@ -833,9 +833,9 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name, type="stable_diffusion"):
clip_type = comfy.sd.CLIPType.STABLE_DIFFUSION
clip_type = sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade":
clip_type = comfy.sd.CLIPType.STABLE_CASCADE
clip_type = sd.CLIPType.STABLE_CASCADE
clip_path = folder_paths.get_full_path("clip", clip_name)
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)

View File

@ -1,50 +1,96 @@
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TypedDict, Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable
from typing_extensions import TypedDict, NotRequired, Generic
from typing import Union, Optional, Sequence, Dict, ClassVar, Protocol, Tuple, TypeVar, Any, Literal, \
Callable, List
T = TypeVar('T')
class NumberSpecOptions(TypedDict, total=False):
default: Union[int, float]
min: Union[int, float]
max: Union[int, float]
step: Union[int, float]
round: int
class IntSpecOptions(TypedDict, total=True):
default: int
min: int
max: int
step: NotRequired[int]
display: NotRequired[Literal["number", "slider"]]
IntSpec = Dict[str, Union[
Literal["INT"],
Tuple[Literal["INT"], Dict[str, Union[int, float, str]]]
]]
FloatSpec = Dict[str, Union[
Literal["FLOAT"],
Tuple[Literal["FLOAT"], Dict[str, Union[int, float, str]]]
]]
StringSpec = Dict[str, Union[
Literal["STRING"],
Tuple[Literal["STRING"], Dict[str, str]]
]]
ChoiceSpec = Dict[str, Union[
Sequence[str], # Directly a list of choices
Tuple[Sequence[str], Dict[str, Any]] # Choices with additional specifications
]]
ComplexInputSpec = Dict[str, Any]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, ChoiceSpec, ComplexInputSpec]
class FloatSpecOptions(TypedDict, total=True):
default: float
min: float
max: float
step: NotRequired[float]
round: NotRequired[float]
display: NotRequired[Literal["number", "slider"]]
class InputTypes(Protocol):
class StringSpecOptions(TypedDict, total=True):
multiline: NotRequired[bool]
default: NotRequired[str]
# todo: analyze the base_nodes for these types
CommonReturnTypes = Union[
Literal["IMAGE", "STRING", "INT", "FLOAT", "CONDITIONING", "LATENT", "MASK", "MODEL", "VAE", "CLIP"], str]
IntSpec = Tuple[Literal["INT"], IntSpecOptions]
FloatSpec = Tuple[Literal["FLOAT"], FloatSpecOptions]
StringSpec = Tuple[Literal["STRING"], StringSpecOptions]
ChoiceSpec = Tuple[Union[Sequence[str], Sequence[float], Sequence[int]]]
NonPrimitiveTypeSpec = Tuple[CommonReturnTypes]
InputTypeSpec = Union[IntSpec, FloatSpec, StringSpec, ChoiceSpec, NonPrimitiveTypeSpec]
class HiddenSpec(TypedDict, total=True):
prompt: Literal["PROMPT"]
extra_pnginfo: Literal["EXTRA_PNGINFO"]
class InputTypes(TypedDict, total=True):
required: Dict[str, InputTypeSpec]
optional: Optional[Dict[str, InputTypeSpec]]
hidden: Optional[Dict[str, InputTypeSpec]]
optional: NotRequired[Dict[str, InputTypeSpec]]
hidden: NotRequired[HiddenSpec]
ValidateInputsMethod = Optional[Callable[..., Union[bool, str]]]
class FunctionReturnsUIVariables(TypedDict):
ui: dict
class SaveNodeResult(TypedDict, total=True):
abs_path: NotRequired[str]
filename: str
subfolder: str
type: Literal["output", "input", "temp"]
class UIImagesImagesResult(TypedDict, total=True):
images: List[SaveNodeResult]
class UIImagesResult(TypedDict, total=True):
ui: UIImagesImagesResult
class UILatentsLatentsResult(TypedDict, total=True):
latents: List[SaveNodeResult]
class UILatentsResult(TypedDict, total=True):
ui: UILatentsLatentsResult
ValidatedNodeResult = Union[Tuple, UIImagesResult, UILatentsResult, FunctionReturnsUIVariables]
class CustomNode(Protocol):
@classmethod
def INPUT_TYPES(cls) -> InputTypes: ...
@ -52,9 +98,9 @@ class CustomNode(Protocol):
# Optional method signature for VALIDATE_INPUTS
VALIDATE_INPUTS: ClassVar[ValidateInputsMethod] = None
RETURN_TYPES: ClassVar[Sequence[str]]
RETURN_NAMES: Optional[ClassVar[Tuple[str]]]
OUTPUT_IS_LIST: Optional[ClassVar[Sequence[bool]]]
RETURN_TYPES: ClassVar[Tuple[CommonReturnTypes, ...]]
RETURN_NAMES: Optional[ClassVar[Tuple[str, ...]]]
OUTPUT_IS_LIST: Optional[ClassVar[Tuple[bool, ...]]]
INPUT_IS_LIST: Optional[ClassVar[bool]]
FUNCTION: ClassVar[str]
CATEGORY: ClassVar[str]