Improve compatibility with comfyui-extra-models, improve API

This commit is contained in:
doctorpangloss 2024-05-30 16:50:34 -07:00
parent 8a3b49eb49
commit 3125366eda
12 changed files with 92 additions and 56 deletions

View File

@@ -7,9 +7,8 @@ import sys
import time
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
from pkg_resources import resource_filename
from ..cli_args import args
from ..component_model.files import get_package_as_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
@@ -94,7 +93,7 @@ else:
models_dir = os.path.join(base_path, "models")
folder_names_and_paths = FolderNames(models_dir)
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))

View File

@@ -23,7 +23,6 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo
from aiohttp import web
from can_ada import URL, parse as urlparse
from pkg_resources import resource_filename
from typing_extensions import NamedTuple
import comfy.interruption
@@ -38,6 +37,7 @@ from ..cmd import folder_paths
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path
from ..component_model.files import get_package_as_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
ExecutionStatus
from ..digest import digest
@@ -115,7 +115,7 @@ class PromptServer(ExecutorToClientProgress):
self.sockets = dict()
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
if not os.path.exists(web_root_path):
web_root_path = resource_filename('comfy', 'web/')
web_root_path = get_package_as_path('comfy', 'web/')
self.web_root = web_root_path
routes = web.RouteTableDef()
self.routes: web.RouteTableDef = routes

View File

@@ -0,0 +1,52 @@
from __future__ import annotations
import json
import os
from importlib import resources as resources
from typing import Optional
def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_package: str, package: str = 'comfy') -> dict:
"""
Given a package and a filename inside the package, returns it as a JSON dict; or, returns the file pointed to by
config_dict_or_path, when it is not None and when it exists
:param config_dict_or_path: a file path or dict pointing to a JSON file. If it exists, it is parsed and returned. Otherwise, when None, falls back to other defaults
:param config_path_inside_package: a filename inside a package
:param package: a package containing the file
:return:
"""
config: dict | None = None
if config_dict_or_path is None:
config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
if isinstance(config_dict_or_path, str):
if config_dict_or_path.startswith("{"):
config = json.loads(config_dict_or_path)
else:
if not os.path.exists(config_dict_or_path):
with resources.as_file(resources.files(package) / config_path_inside_package) as config_path:
with open(config_path) as f:
config = json.load(f)
else:
with open(config_dict_or_path) as f:
config = json.load(f)
elif isinstance(config_dict_or_path, dict):
config = config_dict_or_path
assert config is not None
return config
def get_package_as_path(package: str, subdir: Optional[str] = None) -> str:
"""
Gets the path on the file system to a package. This unpacks it completely.
:param package: the package containing the files
:param subdir: if specified, a subdirectory containing files (and not python packages), such as a web/ directory inside a package
:return:
"""
traversable = resources.files(package)
if subdir is not None:
traversable = traversable / subdir
return os.path.commonpath(list(map(str, traversable.iterdir())))
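A usage sketch for get_path_as_dict, exercising its three accepted input forms; the filename matches the sd1_clip.py call site below, and the "hidden_size" key is illustrative only:

from comfy.component_model.files import get_path_as_dict

# None falls back to the module-relative file, then to the packaged resource
config = get_path_as_dict(None, "sd1_clip_config.json")

# a JSON string (detected by its leading "{") is parsed directly
config = get_path_as_dict('{"hidden_size": 768}', "sd1_clip_config.json")

# a dict passes through unchanged
config = get_path_as_dict({"hidden_size": 768}, "sd1_clip_config.json")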

View File

@@ -167,7 +167,6 @@ KNOWN_CHECKPOINTS = [
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
]
KNOWN_UNCLIP_CHECKPOINTS = [
@@ -304,6 +303,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
"microsoft/Phi-3-mini-4k-instruct",
}
KNOWN_UNET_MODELS: List[Union[CivitFile, HuggingFile]] = []
KNOWN_CLIP_MODELS: List[Union[CivitFile, HuggingFile]] = []
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
if args.disable_known_models:
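add_known_models is the registration hook that downstream packs such as comfyui-extra-models can call to contribute entries to these lists. A hypothetical registration, reusing the CivitFile shape from the checkpoint list above (the IDs and filename are placeholders, and CivitFile is assumed to be importable from the same module):

from comfy.model_downloader import add_known_models, KNOWN_CLIP_MODELS, CivitFile

# placeholder Civitai model id, version id, and filename
add_known_models("clip", KNOWN_CLIP_MODELS,
                 CivitFile(123456, 654321, filename="example_clip.safetensors"))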

View File

@@ -26,7 +26,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..execution_context import current_execution_context
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
@@ -799,7 +799,7 @@ class ControlNetApplyAdvanced:
class UNETLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"),),
return {"required": { "unet_name": (get_filename_list_with_downloadable("unet", KNOWN_UNET_MODELS),),
}}
RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet"
@@ -807,14 +807,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders"
def load_unet(self, unet_name):
unet_path = folder_paths.get_full_path("unet", unet_name)
unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
model = sd.load_unet(unet_path)
return (model,)
class CLIPLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"),),
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
"type": (["stable_diffusion", "stable_cascade"], ),
}}
RETURN_TYPES = ("CLIP",)
@@ -823,11 +823,14 @@ class CLIPLoader:
CATEGORY = "advanced/loaders"
def load_clip(self, clip_name, type="stable_diffusion"):
clip_type = sd.CLIPType.STABLE_DIFFUSION
if type == "stable_cascade":
if type == "stable_diffusion":
clip_type = sd.CLIPType.STABLE_DIFFUSION
elif type == "stable_cascade":
clip_type = sd.CLIPType.STABLE_CASCADE
else:
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
clip_path = folder_paths.get_full_path("clip", clip_name)
clip_path = get_or_download("clip", clip_name, KNOWN_CLIP_MODELS)
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,)
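UNETLoader and CLIPLoader now follow the same two-step pattern as the other loaders in this file: advertise the union of local files and known downloadable models in the widget, then fetch on demand when the workflow runs. A sketch of that flow, with behavior inferred from the call sites above:

from comfy.model_downloader import (get_filename_list_with_downloadable,
                                    get_or_download, KNOWN_CLIP_MODELS)

# widget choices: files already on disk plus known, not-yet-downloaded models
choices = get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS)

# resolving a choice downloads the file first if it is missing locally
clip_path = get_or_download("clip", choices[0], KNOWN_CLIP_MODELS)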

View File

@@ -10,10 +10,10 @@ from functools import reduce
from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode
from pkg_resources import resource_filename
from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer
from ..component_model.files import get_package_as_path
_comfy_nodes: ExportedNodes = ExportedNodes()
@@ -28,7 +28,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleType
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
if web_directory:
# load the extension resources path
abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory))
abs_web_directory = os.path.abspath(get_package_as_path(module.__name__, web_directory))
if not os.path.isdir(abs_web_directory):
abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
if not os.path.isdir(abs_web_directory):

View File

@@ -1,8 +1,6 @@
from __future__ import annotations
import copy
import importlib.resources as resources
import json
import logging
import os
import traceback
@@ -10,36 +8,12 @@ import zipfile
from typing import Tuple, Sequence, TypeVar
import torch
from pkg_resources import resource_filename
from transformers import CLIPTokenizer, PreTrainedTokenizerBase
from . import clip_model
from . import model_management
from . import ops
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
config: dict | None = None
if text_model_config_or_path is None:
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
if isinstance(text_model_config_or_path, str):
if text_model_config_or_path.startswith("{"):
config = json.loads(text_model_config_or_path)
else:
if not os.path.exists(text_model_config_or_path):
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
with open(config_path) as f:
config = json.load(f)
else:
with open(text_model_config_or_path) as f:
config = json.load(f)
elif isinstance(text_model_config_or_path, dict):
config = text_model_config_or_path
assert config is not None
return config
from .component_model.files import get_path_as_dict, get_package_as_path
def gen_empty_tokens(special_tokens, length):
@@ -109,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS
config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json")
config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json")
self.transformer = model_class(config, dtype, device, ops.manual_cast)
self.num_layers = self.transformer.num_layers
@@ -402,7 +376,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# package based
tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/')
tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
self.tokenizer_class = tokenizer_class
self.tokenizer_path = tokenizer_path
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)
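The resolved directory is handed straight to transformers, which expects a local folder containing tokenizer_config.json and its companion files. A minimal check of what the package-based fallback produces, assuming an installed comfy distribution:

import os
from transformers import CLIPTokenizer
from comfy.component_model.files import get_package_as_path

tokenizer_path = get_package_as_path("comfy.sd1_tokenizer")
assert os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json"))
tokenizer = CLIPTokenizer.from_pretrained(tokenizer_path)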

View File

@@ -1,6 +1,6 @@
from . import sd1_clip
from .sd1_clip import get_clip_config_dict
from .component_model.files import get_path_as_dict
class SD2ClipHModel(sd1_clip.SDClipModel):
@@ -9,7 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
layer = "hidden"
layer_idx = -2
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json")
textmodel_json_config = get_path_as_dict(textmodel_json_config, "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})

View File

@@ -3,7 +3,7 @@ import copy
import torch
from . import sd1_clip
from .sd1_clip import get_clip_config_dict
from .component_model.files import get_path_as_dict
class SDXLClipG(sd1_clip.SDClipModel):
@@ -12,7 +12,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer = "hidden"
layer_idx = -2
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
@@ -91,7 +91,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json")
textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)

View File

@@ -3,6 +3,7 @@ from __future__ import annotations
import contextlib
import logging
import math
import os
import os.path
import random
import struct

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any, List, Dict
from typing import Any, Dict, Optional
import torch
from fastchat.model import get_conversation_template
@@ -31,16 +31,20 @@ class TransformersLoader(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),)
"ckpt_name": (huggingface_repos(),),
"subfolder": ("STRING", {})
}
}
RETURN_TYPES = "MODEL",
FUNCTION = "execute"
def execute(self, ckpt_name: str):
def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
with comfy_tqdm():
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True)
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
return model_managed,
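The new subfolder input maps onto the subfolder keyword that transformers' from_pretrained accepts for repositories that nest their weights in a subdirectory. A reduced sketch of the forwarding logic (the helper name is hypothetical):

from typing import Optional
from transformers import AutoModelForCausalLM

def load_with_optional_subfolder(ckpt_name: str, subfolder: Optional[str] = None):
    # forward subfolder only when the widget was filled in, mirroring the node above
    hub_kwargs = {}
    if subfolder is not None and subfolder != "":
        hub_kwargs["subfolder"] = subfolder
    return AutoModelForCausalLM.from_pretrained(ckpt_name, trust_remote_code=True, **hub_kwargs)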

View File

@@ -215,7 +215,7 @@ class StringPosixPathJoin(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
f"value{i}": ("STRING", {"default": "", "multiline": False}) for i in range(5)
f"value{i}": ("STRING", {"default": "", "multiline": False, "forceInput": True}) for i in range(5)
}
}
@@ -284,7 +284,7 @@ class DevNullUris(CustomNode):
class StringJoin(CustomNode):
@classmethod
def INPUT_TYPES(cls) -> InputTypes:
required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)}
required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
required["separator"] = ("STRING", {"default": "_"})
return {
"required": required
@@ -304,7 +304,7 @@ class StringToUri(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("STRING", {"default": "", "multiline": True}),
"value": ("STRING", {"default": "", "multiline": True, "forceInput": True}),
"batch": ("INT", {"default": 1})
}
}
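Assuming stock ComfyUI's forceInput semantics, these STRING inputs now render as connectable sockets rather than inline text widgets, so the string utilities compose with upstream nodes instead of only accepting typed-in values. A minimal hypothetical node showing the option in isolation:

class EchoString:
    # forceInput turns the default widget into a connection slot, so value
    # must arrive from another node's STRING output
    @classmethod
    def INPUT_TYPES(cls):
        return {"required": {"value": ("STRING", {"default": "", "forceInput": True})}}

    RETURN_TYPES = ("STRING",)
    FUNCTION = "execute"

    def execute(self, value):
        return (value,)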