mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-24 21:30:15 +08:00
Improve compatibility with comfyui-extra-models, improve API
This commit is contained in:
parent
8a3b49eb49
commit
3125366eda
@ -7,9 +7,8 @@ import sys
|
|||||||
import time
|
import time
|
||||||
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
||||||
|
|
||||||
from pkg_resources import resource_filename
|
|
||||||
|
|
||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
|
from ..component_model.files import get_package_as_path
|
||||||
|
|
||||||
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
|
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
|
||||||
|
|
||||||
@ -94,7 +93,7 @@ else:
|
|||||||
models_dir = os.path.join(base_path, "models")
|
models_dir = os.path.join(base_path, "models")
|
||||||
folder_names_and_paths = FolderNames(models_dir)
|
folder_names_and_paths = FolderNames(models_dir)
|
||||||
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
||||||
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
|
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
|
||||||
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
||||||
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
|
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
|
||||||
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
|
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
|
||||||
|
|||||||
@ -23,7 +23,6 @@ from PIL import Image
|
|||||||
from PIL.PngImagePlugin import PngInfo
|
from PIL.PngImagePlugin import PngInfo
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from can_ada import URL, parse as urlparse
|
from can_ada import URL, parse as urlparse
|
||||||
from pkg_resources import resource_filename
|
|
||||||
from typing_extensions import NamedTuple
|
from typing_extensions import NamedTuple
|
||||||
|
|
||||||
import comfy.interruption
|
import comfy.interruption
|
||||||
@ -38,6 +37,7 @@ from ..cmd import folder_paths
|
|||||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
|
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
|
||||||
from ..component_model.executor_types import ExecutorToClientProgress
|
from ..component_model.executor_types import ExecutorToClientProgress
|
||||||
from ..component_model.file_output_path import file_output_path
|
from ..component_model.file_output_path import file_output_path
|
||||||
|
from ..component_model.files import get_package_as_path
|
||||||
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
|
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
|
||||||
ExecutionStatus
|
ExecutionStatus
|
||||||
from ..digest import digest
|
from ..digest import digest
|
||||||
@ -115,7 +115,7 @@ class PromptServer(ExecutorToClientProgress):
|
|||||||
self.sockets = dict()
|
self.sockets = dict()
|
||||||
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
|
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
|
||||||
if not os.path.exists(web_root_path):
|
if not os.path.exists(web_root_path):
|
||||||
web_root_path = resource_filename('comfy', 'web/')
|
web_root_path = get_package_as_path('comfy', 'web/')
|
||||||
self.web_root = web_root_path
|
self.web_root = web_root_path
|
||||||
routes = web.RouteTableDef()
|
routes = web.RouteTableDef()
|
||||||
self.routes: web.RouteTableDef = routes
|
self.routes: web.RouteTableDef = routes
|
||||||
|
|||||||
52
comfy/component_model/files.py
Normal file
52
comfy/component_model/files.py
Normal file
@ -0,0 +1,52 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
from importlib import resources as resources
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_package: str, package: str = 'comfy') -> dict:
|
||||||
|
"""
|
||||||
|
Given a package and a filename inside the package, returns it as a JSON dict; or, returns the file pointed to by
|
||||||
|
config_dict_or_path, when it is not None and when it exists
|
||||||
|
|
||||||
|
:param config_dict_or_path: a file path or dict pointing to a JSON file. If it exists, it is parsed and returned. Otherwise, when None, falls back to other defaults
|
||||||
|
:param config_path_inside_package: a filename inside a package
|
||||||
|
:param package: a package containing the file
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
config: dict | None = None
|
||||||
|
|
||||||
|
if config_dict_or_path is None:
|
||||||
|
config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
|
||||||
|
|
||||||
|
if isinstance(config_dict_or_path, str):
|
||||||
|
if config_dict_or_path.startswith("{"):
|
||||||
|
config = json.loads(config_dict_or_path)
|
||||||
|
else:
|
||||||
|
if not os.path.exists(config_dict_or_path):
|
||||||
|
with resources.as_file(resources.files(package) / config_path_inside_package) as config_path:
|
||||||
|
with open(config_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
else:
|
||||||
|
with open(config_dict_or_path) as f:
|
||||||
|
config = json.load(f)
|
||||||
|
elif isinstance(config_dict_or_path, dict):
|
||||||
|
config = config_dict_or_path
|
||||||
|
|
||||||
|
assert config is not None
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def get_package_as_path(package: str, subdir: Optional[str] = None) -> str:
|
||||||
|
"""
|
||||||
|
Gets the path on the file system to a package. This unpacks it completely.
|
||||||
|
:param package: the package containing the files
|
||||||
|
:param subdir: if specified, a subdirectory containing files (and not python packages), such as a web/ directory inside a package
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
traversable = resources.files(package)
|
||||||
|
if subdir is not None:
|
||||||
|
traversable = traversable / subdir
|
||||||
|
return os.path.commonpath(list(map(str, traversable.iterdir())))
|
||||||
@ -167,7 +167,6 @@ KNOWN_CHECKPOINTS = [
|
|||||||
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
|
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
|
||||||
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
|
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
|
||||||
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
|
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
|
||||||
|
|
||||||
]
|
]
|
||||||
|
|
||||||
KNOWN_UNCLIP_CHECKPOINTS = [
|
KNOWN_UNCLIP_CHECKPOINTS = [
|
||||||
@ -304,6 +303,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
|
|||||||
"microsoft/Phi-3-mini-4k-instruct",
|
"microsoft/Phi-3-mini-4k-instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = []
|
||||||
|
|
||||||
|
KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = []
|
||||||
|
|
||||||
|
|
||||||
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
|
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
|
||||||
if args.disable_known_models:
|
if args.disable_known_models:
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from ..cli_args import args
|
|||||||
from ..cmd import folder_paths, latent_preview
|
from ..cmd import folder_paths, latent_preview
|
||||||
from ..execution_context import current_execution_context
|
from ..execution_context import current_execution_context
|
||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos
|
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
||||||
from ..nodes.common import MAX_RESOLUTION
|
from ..nodes.common import MAX_RESOLUTION
|
||||||
from .. import controlnet
|
from .. import controlnet
|
||||||
from ..open_exr import load_exr
|
from ..open_exr import load_exr
|
||||||
@ -799,7 +799,7 @@ class ControlNetApplyAdvanced:
|
|||||||
class UNETLoader:
|
class UNETLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"),),
|
return {"required": { "unet_name": (get_filename_list_with_downloadable("unet", KNOWN_UNET_MODELS),),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
@ -807,14 +807,14 @@ class UNETLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_unet(self, unet_name):
|
def load_unet(self, unet_name):
|
||||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
|
||||||
model = sd.load_unet(unet_path)
|
model = sd.load_unet(unet_path)
|
||||||
return (model,)
|
return (model,)
|
||||||
|
|
||||||
class CLIPLoader:
|
class CLIPLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"),),
|
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
|
||||||
"type": (["stable_diffusion", "stable_cascade"], ),
|
"type": (["stable_diffusion", "stable_cascade"], ),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP",)
|
RETURN_TYPES = ("CLIP",)
|
||||||
@ -823,11 +823,14 @@ class CLIPLoader:
|
|||||||
CATEGORY = "advanced/loaders"
|
CATEGORY = "advanced/loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||||
clip_type = sd.CLIPType.STABLE_DIFFUSION
|
if type == "stable_diffusion":
|
||||||
if type == "stable_cascade":
|
clip_type = sd.CLIPType.STABLE_DIFFUSION
|
||||||
|
elif type == "stable_cascade":
|
||||||
clip_type = sd.CLIPType.STABLE_CASCADE
|
clip_type = sd.CLIPType.STABLE_CASCADE
|
||||||
|
else:
|
||||||
|
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
|
||||||
|
|
||||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
clip_path = get_or_download("clip", clip_name, KNOWN_CLIP_MODELS)
|
||||||
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||||
return (clip,)
|
return (clip,)
|
||||||
|
|
||||||
|
|||||||
@ -10,10 +10,10 @@ from functools import reduce
|
|||||||
from importlib.metadata import entry_points
|
from importlib.metadata import entry_points
|
||||||
|
|
||||||
from opentelemetry.trace import Span, Status, StatusCode
|
from opentelemetry.trace import Span, Status, StatusCode
|
||||||
from pkg_resources import resource_filename
|
|
||||||
|
|
||||||
from .package_typing import ExportedNodes
|
from .package_typing import ExportedNodes
|
||||||
from ..cmd.main_pre import tracer
|
from ..cmd.main_pre import tracer
|
||||||
|
from ..component_model.files import get_package_as_path
|
||||||
|
|
||||||
_comfy_nodes: ExportedNodes = ExportedNodes()
|
_comfy_nodes: ExportedNodes = ExportedNodes()
|
||||||
|
|
||||||
@ -28,7 +28,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
|
|||||||
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
|
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
|
||||||
if web_directory:
|
if web_directory:
|
||||||
# load the extension resources path
|
# load the extension resources path
|
||||||
abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory))
|
abs_web_directory = os.path.abspath(get_package_as_path(module.__name__, web_directory))
|
||||||
if not os.path.isdir(abs_web_directory):
|
if not os.path.isdir(abs_web_directory):
|
||||||
abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
|
abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
|
||||||
if not os.path.isdir(abs_web_directory):
|
if not os.path.isdir(abs_web_directory):
|
||||||
|
|||||||
@ -1,8 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import copy
|
import copy
|
||||||
import importlib.resources as resources
|
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import traceback
|
import traceback
|
||||||
@ -10,36 +8,12 @@ import zipfile
|
|||||||
from typing import Tuple, Sequence, TypeVar
|
from typing import Tuple, Sequence, TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from pkg_resources import resource_filename
|
|
||||||
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
|
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
|
||||||
|
|
||||||
from . import clip_model
|
from . import clip_model
|
||||||
from . import model_management
|
from . import model_management
|
||||||
from . import ops
|
from . import ops
|
||||||
|
from .component_model.files import get_path_as_dict, get_package_as_path
|
||||||
|
|
||||||
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
|
|
||||||
config: dict | None = None
|
|
||||||
|
|
||||||
if text_model_config_or_path is None:
|
|
||||||
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
|
|
||||||
|
|
||||||
if isinstance(text_model_config_or_path, str):
|
|
||||||
if text_model_config_or_path.startswith("{"):
|
|
||||||
config = json.loads(text_model_config_or_path)
|
|
||||||
else:
|
|
||||||
if not os.path.exists(text_model_config_or_path):
|
|
||||||
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
|
|
||||||
with open(config_path) as f:
|
|
||||||
config = json.load(f)
|
|
||||||
else:
|
|
||||||
with open(text_model_config_or_path) as f:
|
|
||||||
config = json.load(f)
|
|
||||||
elif isinstance(text_model_config_or_path, dict):
|
|
||||||
config = text_model_config_or_path
|
|
||||||
|
|
||||||
assert config is not None
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def gen_empty_tokens(special_tokens, length):
|
def gen_empty_tokens(special_tokens, length):
|
||||||
@ -109,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
|||||||
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
|
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
|
||||||
assert layer in self.LAYERS
|
assert layer in self.LAYERS
|
||||||
|
|
||||||
config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json")
|
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json")
|
||||||
self.transformer = model_class(config, dtype, device, ops.manual_cast)
|
self.transformer = model_class(config, dtype, device, ops.manual_cast)
|
||||||
self.num_layers = self.transformer.num_layers
|
self.num_layers = self.transformer.num_layers
|
||||||
|
|
||||||
@ -402,7 +376,7 @@ class SDTokenizer:
|
|||||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||||
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
|
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
|
||||||
# package based
|
# package based
|
||||||
tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
|
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
|
||||||
self.tokenizer_class = tokenizer_class
|
self.tokenizer_class = tokenizer_class
|
||||||
self.tokenizer_path = tokenizer_path
|
self.tokenizer_path = tokenizer_path
|
||||||
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
|
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from . import sd1_clip
|
from . import sd1_clip
|
||||||
|
|
||||||
from .sd1_clip import get_clip_config_dict
|
from .component_model.files import get_path_as_dict
|
||||||
|
|
||||||
|
|
||||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||||
@ -9,7 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
|||||||
layer = "hidden"
|
layer = "hidden"
|
||||||
layer_idx = -2
|
layer_idx = -2
|
||||||
|
|
||||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json")
|
textmodel_json_config = get_path_as_dict(textmodel_json_config, "sd2_clip_config.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ import copy
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from . import sd1_clip
|
from . import sd1_clip
|
||||||
from .sd1_clip import get_clip_config_dict
|
from .component_model.files import get_path_as_dict
|
||||||
|
|
||||||
|
|
||||||
class SDXLClipG(sd1_clip.SDClipModel):
|
class SDXLClipG(sd1_clip.SDClipModel):
|
||||||
@ -12,7 +12,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
|
|||||||
layer = "hidden"
|
layer = "hidden"
|
||||||
layer_idx = -2
|
layer_idx = -2
|
||||||
|
|
||||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
|
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||||
|
|
||||||
@ -91,7 +91,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
|||||||
|
|
||||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
|
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
|
||||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
|
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||||
|
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import contextlib
|
import contextlib
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
|
import os
|
||||||
import os.path
|
import os.path
|
||||||
import random
|
import random
|
||||||
import struct
|
import struct
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
from typing import Any, List, Dict
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from fastchat.model import get_conversation_template
|
from fastchat.model import get_conversation_template
|
||||||
@ -31,16 +31,20 @@ class TransformersLoader(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"ckpt_name": (huggingface_repos(),)
|
"ckpt_name": (huggingface_repos(),),
|
||||||
|
"subfolder": ("STRING", {})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
RETURN_TYPES = "MODEL",
|
RETURN_TYPES = "MODEL",
|
||||||
FUNCTION = "execute"
|
FUNCTION = "execute"
|
||||||
|
|
||||||
def execute(self, ckpt_name: str):
|
def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
|
||||||
|
hub_kwargs = {}
|
||||||
|
if subfolder is not None and subfolder != "":
|
||||||
|
hub_kwargs["subfolder"] = subfolder
|
||||||
with comfy_tqdm():
|
with comfy_tqdm():
|
||||||
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True)
|
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
|
||||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
|
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
|
||||||
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
|
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
|
||||||
return model_managed,
|
return model_managed,
|
||||||
|
|||||||
@ -215,7 +215,7 @@ class StringPosixPathJoin(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
f"value{i}": ("STRING", {"default": "", "multiline": False}) for i in range(5)
|
f"value{i}": ("STRING", {"default": "", "multiline": False, "forceInput": True}) for i in range(5)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -284,7 +284,7 @@ class DevNullUris(CustomNode):
|
|||||||
class StringJoin(CustomNode):
|
class StringJoin(CustomNode):
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)}
|
required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
|
||||||
required["separator"] = ("STRING", {"default": "_"})
|
required["separator"] = ("STRING", {"default": "_"})
|
||||||
return {
|
return {
|
||||||
"required": required
|
"required": required
|
||||||
@ -304,7 +304,7 @@ class StringToUri(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"value": ("STRING", {"default": "", "multiline": True}),
|
"value": ("STRING", {"default": "", "multiline": True, "forceInput": True}),
|
||||||
"batch": ("INT", {"default": 1})
|
"batch": ("INT", {"default": 1})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user