mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Improve compatibility with comfyui-extra-models, improve API
This commit is contained in:
parent
8a3b49eb49
commit
3125366eda
@ -7,9 +7,8 @@ import sys
|
||||
import time
|
||||
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
||||
|
||||
from pkg_resources import resource_filename
|
||||
|
||||
from ..cli_args import args
|
||||
from ..component_model.files import get_package_as_path
|
||||
|
||||
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
|
||||
|
||||
@ -94,7 +93,7 @@ else:
|
||||
models_dir = os.path.join(base_path, "models")
|
||||
folder_names_and_paths = FolderNames(models_dir)
|
||||
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
|
||||
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
|
||||
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))
|
||||
|
||||
@ -23,7 +23,6 @@ from PIL import Image
|
||||
from PIL.PngImagePlugin import PngInfo
|
||||
from aiohttp import web
|
||||
from can_ada import URL, parse as urlparse
|
||||
from pkg_resources import resource_filename
|
||||
from typing_extensions import NamedTuple
|
||||
|
||||
import comfy.interruption
|
||||
@ -38,6 +37,7 @@ from ..cmd import folder_paths
|
||||
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
|
||||
from ..component_model.executor_types import ExecutorToClientProgress
|
||||
from ..component_model.file_output_path import file_output_path
|
||||
from ..component_model.files import get_package_as_path
|
||||
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
|
||||
ExecutionStatus
|
||||
from ..digest import digest
|
||||
@ -115,7 +115,7 @@ class PromptServer(ExecutorToClientProgress):
|
||||
self.sockets = dict()
|
||||
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
|
||||
if not os.path.exists(web_root_path):
|
||||
web_root_path = resource_filename('comfy', 'web/')
|
||||
web_root_path = get_package_as_path('comfy', 'web/')
|
||||
self.web_root = web_root_path
|
||||
routes = web.RouteTableDef()
|
||||
self.routes: web.RouteTableDef = routes
|
||||
|
||||
52
comfy/component_model/files.py
Normal file
52
comfy/component_model/files.py
Normal file
@ -0,0 +1,52 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import os
|
||||
from importlib import resources as resources
|
||||
from typing import Optional
|
||||
|
||||
|
||||
def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_package: str, package: str = 'comfy') -> dict:
|
||||
"""
|
||||
Given a package and a filename inside the package, returns it as a JSON dict; or, returns the file pointed to by
|
||||
config_dict_or_path, when it is not None and when it exists
|
||||
|
||||
:param config_dict_or_path: a file path or dict pointing to a JSON file. If it exists, it is parsed and returned. Otherwise, when None, falls back to other defaults
|
||||
:param config_path_inside_package: a filename inside a package
|
||||
:param package: a package containing the file
|
||||
:return:
|
||||
"""
|
||||
config: dict | None = None
|
||||
|
||||
if config_dict_or_path is None:
|
||||
config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
|
||||
|
||||
if isinstance(config_dict_or_path, str):
|
||||
if config_dict_or_path.startswith("{"):
|
||||
config = json.loads(config_dict_or_path)
|
||||
else:
|
||||
if not os.path.exists(config_dict_or_path):
|
||||
with resources.as_file(resources.files(package) / config_path_inside_package) as config_path:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
with open(config_dict_or_path) as f:
|
||||
config = json.load(f)
|
||||
elif isinstance(config_dict_or_path, dict):
|
||||
config = config_dict_or_path
|
||||
|
||||
assert config is not None
|
||||
return config
|
||||
|
||||
|
||||
def get_package_as_path(package: str, subdir: Optional[str] = None) -> str:
|
||||
"""
|
||||
Gets the path on the file system to a package. This unpacks it completely.
|
||||
:param package: the package containing the files
|
||||
:param subdir: if specified, a subdirectory containing files (and not python packages), such as a web/ directory inside a package
|
||||
:return:
|
||||
"""
|
||||
traversable = resources.files(package)
|
||||
if subdir is not None:
|
||||
traversable = traversable / subdir
|
||||
return os.path.commonpath(list(map(str, traversable.iterdir())))
|
||||
@ -167,7 +167,6 @@ KNOWN_CHECKPOINTS = [
|
||||
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
|
||||
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
|
||||
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
|
||||
|
||||
]
|
||||
|
||||
KNOWN_UNCLIP_CHECKPOINTS = [
|
||||
@ -304,6 +303,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
|
||||
"microsoft/Phi-3-mini-4k-instruct",
|
||||
}
|
||||
|
||||
KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = []
|
||||
|
||||
KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = []
|
||||
|
||||
|
||||
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
|
||||
if args.disable_known_models:
|
||||
|
||||
@ -26,7 +26,7 @@ from ..cli_args import args
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..execution_context import current_execution_context
|
||||
from ..images import open_image
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
||||
from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
from ..open_exr import load_exr
|
||||
@ -799,7 +799,7 @@ class ControlNetApplyAdvanced:
|
||||
class UNETLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"),),
|
||||
return {"required": { "unet_name": (get_filename_list_with_downloadable("unet", KNOWN_UNET_MODELS),),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL",)
|
||||
FUNCTION = "load_unet"
|
||||
@ -807,14 +807,14 @@ class UNETLoader:
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_unet(self, unet_name):
|
||||
unet_path = folder_paths.get_full_path("unet", unet_name)
|
||||
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
|
||||
model = sd.load_unet(unet_path)
|
||||
return (model,)
|
||||
|
||||
class CLIPLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"),),
|
||||
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
|
||||
"type": (["stable_diffusion", "stable_cascade"], ),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP",)
|
||||
@ -823,11 +823,14 @@ class CLIPLoader:
|
||||
CATEGORY = "advanced/loaders"
|
||||
|
||||
def load_clip(self, clip_name, type="stable_diffusion"):
|
||||
clip_type = sd.CLIPType.STABLE_DIFFUSION
|
||||
if type == "stable_cascade":
|
||||
if type == "stable_diffusion":
|
||||
clip_type = sd.CLIPType.STABLE_DIFFUSION
|
||||
elif type == "stable_cascade":
|
||||
clip_type = sd.CLIPType.STABLE_CASCADE
|
||||
else:
|
||||
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
|
||||
|
||||
clip_path = folder_paths.get_full_path("clip", clip_name)
|
||||
clip_path = get_or_download("clip", clip_name, KNOWN_CLIP_MODELS)
|
||||
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
|
||||
return (clip,)
|
||||
|
||||
|
||||
@ -10,10 +10,10 @@ from functools import reduce
|
||||
from importlib.metadata import entry_points
|
||||
|
||||
from opentelemetry.trace import Span, Status, StatusCode
|
||||
from pkg_resources import resource_filename
|
||||
|
||||
from .package_typing import ExportedNodes
|
||||
from ..cmd.main_pre import tracer
|
||||
from ..component_model.files import get_package_as_path
|
||||
|
||||
_comfy_nodes: ExportedNodes = ExportedNodes()
|
||||
|
||||
@ -28,7 +28,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
|
||||
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
|
||||
if web_directory:
|
||||
# load the extension resources path
|
||||
abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory))
|
||||
abs_web_directory = os.path.abspath(get_package_as_path(module.__name__, web_directory))
|
||||
if not os.path.isdir(abs_web_directory):
|
||||
abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
|
||||
if not os.path.isdir(abs_web_directory):
|
||||
|
||||
@ -1,8 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
import importlib.resources as resources
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import traceback
|
||||
@ -10,36 +8,12 @@ import zipfile
|
||||
from typing import Tuple, Sequence, TypeVar
|
||||
|
||||
import torch
|
||||
from pkg_resources import resource_filename
|
||||
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
|
||||
|
||||
from . import clip_model
|
||||
from . import model_management
|
||||
from . import ops
|
||||
|
||||
|
||||
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
|
||||
config: dict | None = None
|
||||
|
||||
if text_model_config_or_path is None:
|
||||
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
|
||||
|
||||
if isinstance(text_model_config_or_path, str):
|
||||
if text_model_config_or_path.startswith("{"):
|
||||
config = json.loads(text_model_config_or_path)
|
||||
else:
|
||||
if not os.path.exists(text_model_config_or_path):
|
||||
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
|
||||
with open(config_path) as f:
|
||||
config = json.load(f)
|
||||
else:
|
||||
with open(text_model_config_or_path) as f:
|
||||
config = json.load(f)
|
||||
elif isinstance(text_model_config_or_path, dict):
|
||||
config = text_model_config_or_path
|
||||
|
||||
assert config is not None
|
||||
return config
|
||||
from .component_model.files import get_path_as_dict, get_package_as_path
|
||||
|
||||
|
||||
def gen_empty_tokens(special_tokens, length):
|
||||
@ -109,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
|
||||
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
|
||||
assert layer in self.LAYERS
|
||||
|
||||
config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json")
|
||||
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json")
|
||||
self.transformer = model_class(config, dtype, device, ops.manual_cast)
|
||||
self.num_layers = self.transformer.num_layers
|
||||
|
||||
@ -402,7 +376,7 @@ class SDTokenizer:
|
||||
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
|
||||
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
|
||||
# package based
|
||||
tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
|
||||
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
|
||||
self.tokenizer_class = tokenizer_class
|
||||
self.tokenizer_path = tokenizer_path
|
||||
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from . import sd1_clip
|
||||
|
||||
from .sd1_clip import get_clip_config_dict
|
||||
from .component_model.files import get_path_as_dict
|
||||
|
||||
|
||||
class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
@ -9,7 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
|
||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json")
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "sd2_clip_config.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})
|
||||
|
||||
|
||||
|
||||
@ -3,7 +3,7 @@ import copy
|
||||
import torch
|
||||
|
||||
from . import sd1_clip
|
||||
from .sd1_clip import get_clip_config_dict
|
||||
from .component_model.files import get_path_as_dict
|
||||
|
||||
|
||||
class SDXLClipG(sd1_clip.SDClipModel):
|
||||
@ -12,7 +12,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
|
||||
layer = "hidden"
|
||||
layer_idx = -2
|
||||
|
||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
|
||||
|
||||
@ -91,7 +91,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
|
||||
|
||||
class StableCascadeClipG(sd1_clip.SDClipModel):
|
||||
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
|
||||
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
|
||||
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
|
||||
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)
|
||||
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import contextlib
|
||||
import logging
|
||||
import math
|
||||
import os
|
||||
import os.path
|
||||
import random
|
||||
import struct
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, List, Dict
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
import torch
|
||||
from fastchat.model import get_conversation_template
|
||||
@ -31,16 +31,20 @@ class TransformersLoader(CustomNode):
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"ckpt_name": (huggingface_repos(),)
|
||||
"ckpt_name": (huggingface_repos(),),
|
||||
"subfolder": ("STRING", {})
|
||||
}
|
||||
}
|
||||
|
||||
RETURN_TYPES = "MODEL",
|
||||
FUNCTION = "execute"
|
||||
|
||||
def execute(self, ckpt_name: str):
|
||||
def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
|
||||
hub_kwargs = {}
|
||||
if subfolder is not None and subfolder != "":
|
||||
hub_kwargs["subfolder"] = subfolder
|
||||
with comfy_tqdm():
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True)
|
||||
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
|
||||
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
|
||||
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
|
||||
return model_managed,
|
||||
|
||||
@ -215,7 +215,7 @@ class StringPosixPathJoin(CustomNode):
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
f"value{i}": ("STRING", {"default": "", "multiline": False}) for i in range(5)
|
||||
f"value{i}": ("STRING", {"default": "", "multiline": False, "forceInput": True}) for i in range(5)
|
||||
}
|
||||
}
|
||||
|
||||
@ -284,7 +284,7 @@ class DevNullUris(CustomNode):
|
||||
class StringJoin(CustomNode):
|
||||
@classmethod
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)}
|
||||
required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
|
||||
required["separator"] = ("STRING", {"default": "_"})
|
||||
return {
|
||||
"required": required
|
||||
@ -304,7 +304,7 @@ class StringToUri(CustomNode):
|
||||
def INPUT_TYPES(cls) -> InputTypes:
|
||||
return {
|
||||
"required": {
|
||||
"value": ("STRING", {"default": "", "multiline": True}),
|
||||
"value": ("STRING", {"default": "", "multiline": True, "forceInput": True}),
|
||||
"batch": ("INT", {"default": 1})
|
||||
}
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user