Improve compatibility with comfyui-extra-models, improve API

This commit is contained in:
doctorpangloss 2024-05-30 16:50:34 -07:00
parent 8a3b49eb49
commit 3125366eda
12 changed files with 92 additions and 56 deletions

View File

@ -7,9 +7,8 @@ import sys
import time import time
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
from pkg_resources import resource_filename
from ..cli_args import args from ..cli_args import args
from ..component_model.files import get_package_as_path
supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl']) supported_pt_extensions = frozenset(['.ckpt', '.pt', '.bin', '.pth', '.safetensors', '.pkl'])
@ -94,7 +93,7 @@ else:
models_dir = os.path.join(base_path, "models") models_dir = os.path.join(base_path, "models")
folder_names_and_paths = FolderNames(models_dir) folder_names_and_paths = FolderNames(models_dir)
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions)) folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"}) folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), get_package_as_path("comfy.configs")], {".yaml"})
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions)) folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions)) folder_names_and_paths["vae"] = FolderPathsTuple("vae", [os.path.join(models_dir, "vae")], set(supported_pt_extensions))
folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions)) folder_names_and_paths["clip"] = FolderPathsTuple("clip", [os.path.join(models_dir, "clip")], set(supported_pt_extensions))

View File

@ -23,7 +23,6 @@ from PIL import Image
from PIL.PngImagePlugin import PngInfo from PIL.PngImagePlugin import PngInfo
from aiohttp import web from aiohttp import web
from can_ada import URL, parse as urlparse from can_ada import URL, parse as urlparse
from pkg_resources import resource_filename
from typing_extensions import NamedTuple from typing_extensions import NamedTuple
import comfy.interruption import comfy.interruption
@ -38,6 +37,7 @@ from ..cmd import folder_paths
from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue from ..component_model.abstract_prompt_queue import AbstractPromptQueue, AsyncAbstractPromptQueue
from ..component_model.executor_types import ExecutorToClientProgress from ..component_model.executor_types import ExecutorToClientProgress
from ..component_model.file_output_path import file_output_path from ..component_model.file_output_path import file_output_path
from ..component_model.files import get_package_as_path
from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \ from ..component_model.queue_types import QueueItem, HistoryEntry, BinaryEventTypes, TaskInvocation, ExecutionError, \
ExecutionStatus ExecutionStatus
from ..digest import digest from ..digest import digest
@ -115,7 +115,7 @@ class PromptServer(ExecutorToClientProgress):
self.sockets = dict() self.sockets = dict()
web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web") web_root_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "../web")
if not os.path.exists(web_root_path): if not os.path.exists(web_root_path):
web_root_path = resource_filename('comfy', 'web/') web_root_path = get_package_as_path('comfy', 'web/')
self.web_root = web_root_path self.web_root = web_root_path
routes = web.RouteTableDef() routes = web.RouteTableDef()
self.routes: web.RouteTableDef = routes self.routes: web.RouteTableDef = routes

View File

@ -0,0 +1,52 @@
from __future__ import annotations
import json
import os
from importlib import resources as resources
from typing import Optional
def get_path_as_dict(config_dict_or_path: str | dict | None, config_path_inside_package: str, package: str = 'comfy') -> dict:
"""
Given a package and a filename inside the package, returns it as a JSON dict; or, returns the file pointed to by
config_dict_or_path, when it is not None and when it exists
:param config_dict_or_path: a file path or dict pointing to a JSON file. If it exists, it is parsed and returned. Otherwise, when None, falls back to other defaults
:param config_path_inside_package: a filename inside a package
:param package: a package containing the file
:return:
"""
config: dict | None = None
if config_dict_or_path is None:
config_dict_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), config_path_inside_package)
if isinstance(config_dict_or_path, str):
if config_dict_or_path.startswith("{"):
config = json.loads(config_dict_or_path)
else:
if not os.path.exists(config_dict_or_path):
with resources.as_file(resources.files(package) / config_path_inside_package) as config_path:
with open(config_path) as f:
config = json.load(f)
else:
with open(config_dict_or_path) as f:
config = json.load(f)
elif isinstance(config_dict_or_path, dict):
config = config_dict_or_path
assert config is not None
return config
def get_package_as_path(package: str, subdir: Optional[str] = None) -> str:
"""
Gets the path on the file system to a package. This unpacks it completely.
:param package: the package containing the files
:param subdir: if specified, a subdirectory containing files (and not python packages), such as a web/ directory inside a package
:return:
"""
traversable = resources.files(package)
if subdir is not None:
traversable = traversable / subdir
return os.path.commonpath(list(map(str, traversable.iterdir())))

View File

@ -167,7 +167,6 @@ KNOWN_CHECKPOINTS = [
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"), CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"), CivitFile(112902, 351306, filename="dreamshaperXL_v21TurboDPMSDE.safetensors"),
CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"), CivitFile(139562, 344487, filename="realvisxlV40_v40Bakedvae.safetensors"),
] ]
KNOWN_UNCLIP_CHECKPOINTS = [ KNOWN_UNCLIP_CHECKPOINTS = [
@ -304,6 +303,10 @@ KNOWN_HUGGINGFACE_MODEL_REPOS = {
"microsoft/Phi-3-mini-4k-instruct", "microsoft/Phi-3-mini-4k-instruct",
} }
KNOWN_UNET_MODELS: List[Union[CivitFile | HuggingFile]] = []
KNOWN_CLIP_MODELS: List[Union[CivitFile | HuggingFile]] = []
def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]: def add_known_models(folder_name: str, symbol: List[Union[CivitFile, HuggingFile]], *models: Union[CivitFile, HuggingFile]) -> List[Union[CivitFile, HuggingFile]]:
if args.disable_known_models: if args.disable_known_models:

View File

@ -26,7 +26,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview from ..cmd import folder_paths, latent_preview
from ..execution_context import current_execution_context from ..execution_context import current_execution_context
from ..images import open_image from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
from ..nodes.common import MAX_RESOLUTION from ..nodes.common import MAX_RESOLUTION
from .. import controlnet from .. import controlnet
from ..open_exr import load_exr from ..open_exr import load_exr
@ -799,7 +799,7 @@ class ControlNetApplyAdvanced:
class UNETLoader: class UNETLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("unet"),), return {"required": { "unet_name": (get_filename_list_with_downloadable("unet", KNOWN_UNET_MODELS),),
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet" FUNCTION = "load_unet"
@ -807,14 +807,14 @@ class UNETLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_unet(self, unet_name): def load_unet(self, unet_name):
unet_path = folder_paths.get_full_path("unet", unet_name) unet_path = get_or_download("unet", unet_name, KNOWN_UNET_MODELS)
model = sd.load_unet(unet_path) model = sd.load_unet(unet_path)
return (model,) return (model,)
class CLIPLoader: class CLIPLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip"),), return {"required": { "clip_name": (get_filename_list_with_downloadable("clip", KNOWN_CLIP_MODELS),),
"type": (["stable_diffusion", "stable_cascade"], ), "type": (["stable_diffusion", "stable_cascade"], ),
}} }}
RETURN_TYPES = ("CLIP",) RETURN_TYPES = ("CLIP",)
@ -823,11 +823,14 @@ class CLIPLoader:
CATEGORY = "advanced/loaders" CATEGORY = "advanced/loaders"
def load_clip(self, clip_name, type="stable_diffusion"): def load_clip(self, clip_name, type="stable_diffusion"):
clip_type = sd.CLIPType.STABLE_DIFFUSION if type == "stable_diffusion":
if type == "stable_cascade": clip_type = sd.CLIPType.STABLE_DIFFUSION
elif type == "stable_cascade":
clip_type = sd.CLIPType.STABLE_CASCADE clip_type = sd.CLIPType.STABLE_CASCADE
else:
logging.warning(f"Unknown clip type argument passed: {type} for model {clip_name}")
clip_path = folder_paths.get_full_path("clip", clip_name) clip_path = get_or_download("clip", clip_name, KNOWN_CLIP_MODELS)
clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type) clip = sd.load_clip(ckpt_paths=[clip_path], embedding_directory=folder_paths.get_folder_paths("embeddings"), clip_type=clip_type)
return (clip,) return (clip,)

View File

@ -10,10 +10,10 @@ from functools import reduce
from importlib.metadata import entry_points from importlib.metadata import entry_points
from opentelemetry.trace import Span, Status, StatusCode from opentelemetry.trace import Span, Status, StatusCode
from pkg_resources import resource_filename
from .package_typing import ExportedNodes from .package_typing import ExportedNodes
from ..cmd.main_pre import tracer from ..cmd.main_pre import tracer
from ..component_model.files import get_package_as_path
_comfy_nodes: ExportedNodes = ExportedNodes() _comfy_nodes: ExportedNodes = ExportedNodes()
@ -28,7 +28,7 @@ def _import_nodes_in_module(exported_nodes: ExportedNodes, module: types.ModuleT
exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names) exported_nodes.NODE_DISPLAY_NAME_MAPPINGS.update(node_display_names)
if web_directory: if web_directory:
# load the extension resources path # load the extension resources path
abs_web_directory = os.path.abspath(resource_filename(module.__name__, web_directory)) abs_web_directory = os.path.abspath(get_package_as_path(module.__name__, web_directory))
if not os.path.isdir(abs_web_directory): if not os.path.isdir(abs_web_directory):
abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory)) abs_web_directory = os.path.abspath(os.path.join(os.path.dirname(module.__file__), web_directory))
if not os.path.isdir(abs_web_directory): if not os.path.isdir(abs_web_directory):

View File

@ -1,8 +1,6 @@
from __future__ import annotations from __future__ import annotations
import copy import copy
import importlib.resources as resources
import json
import logging import logging
import os import os
import traceback import traceback
@ -10,36 +8,12 @@ import zipfile
from typing import Tuple, Sequence, TypeVar from typing import Tuple, Sequence, TypeVar
import torch import torch
from pkg_resources import resource_filename
from transformers import CLIPTokenizer, PreTrainedTokenizerBase from transformers import CLIPTokenizer, PreTrainedTokenizerBase
from . import clip_model from . import clip_model
from . import model_management from . import model_management
from . import ops from . import ops
from .component_model.files import get_path_as_dict, get_package_as_path
def get_clip_config_dict(text_model_config_or_path: str | dict | None, text_model_config_path_in_comfy: str, package: str = 'comfy') -> dict:
config: dict | None = None
if text_model_config_or_path is None:
text_model_config_or_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), text_model_config_path_in_comfy)
if isinstance(text_model_config_or_path, str):
if text_model_config_or_path.startswith("{"):
config = json.loads(text_model_config_or_path)
else:
if not os.path.exists(text_model_config_or_path):
with resources.as_file(resources.files(package) / text_model_config_path_in_comfy) as config_path:
with open(config_path) as f:
config = json.load(f)
else:
with open(text_model_config_or_path) as f:
config = json.load(f)
elif isinstance(text_model_config_or_path, dict):
config = text_model_config_or_path
assert config is not None
return config
def gen_empty_tokens(special_tokens, length): def gen_empty_tokens(special_tokens, length):
@ -109,7 +83,7 @@ class SDClipModel(torch.nn.Module, ClipTokenWeightEncoder):
special_tokens = {"start": 49406, "end": 49407, "pad": 49407} special_tokens = {"start": 49406, "end": 49407, "pad": 49407}
assert layer in self.LAYERS assert layer in self.LAYERS
config = get_clip_config_dict(textmodel_json_config, "sd1_clip_config.json") config = get_path_as_dict(textmodel_json_config, "sd1_clip_config.json")
self.transformer = model_class(config, dtype, device, ops.manual_cast) self.transformer = model_class(config, dtype, device, ops.manual_cast)
self.num_layers = self.transformer.num_layers self.num_layers = self.transformer.num_layers
@ -402,7 +376,7 @@ class SDTokenizer:
tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer") tokenizer_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), "sd1_tokenizer")
if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")): if not os.path.exists(os.path.join(tokenizer_path, "tokenizer_config.json")):
# package based # package based
tokenizer_path = resource_filename('comfy', 'sd1_tokenizer/') tokenizer_path = get_package_as_path('comfy.sd1_tokenizer')
self.tokenizer_class = tokenizer_class self.tokenizer_class = tokenizer_class
self.tokenizer_path = tokenizer_path self.tokenizer_path = tokenizer_path
self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path) self.tokenizer: PreTrainedTokenizerBase = tokenizer_class.from_pretrained(tokenizer_path)

View File

@ -1,6 +1,6 @@
from . import sd1_clip from . import sd1_clip
from .sd1_clip import get_clip_config_dict from .component_model.files import get_path_as_dict
class SD2ClipHModel(sd1_clip.SDClipModel): class SD2ClipHModel(sd1_clip.SDClipModel):
@ -9,7 +9,7 @@ class SD2ClipHModel(sd1_clip.SDClipModel):
layer = "hidden" layer = "hidden"
layer_idx = -2 layer_idx = -2
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "sd2_clip_config.json") textmodel_json_config = get_path_as_dict(textmodel_json_config, "sd2_clip_config.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0}) super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, special_tokens={"start": 49406, "end": 49407, "pad": 0})

View File

@ -3,7 +3,7 @@ import copy
import torch import torch
from . import sd1_clip from . import sd1_clip
from .sd1_clip import get_clip_config_dict from .component_model.files import get_path_as_dict
class SDXLClipG(sd1_clip.SDClipModel): class SDXLClipG(sd1_clip.SDClipModel):
@ -12,7 +12,7 @@ class SDXLClipG(sd1_clip.SDClipModel):
layer = "hidden" layer = "hidden"
layer_idx = -2 layer_idx = -2
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json") textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False) special_tokens={"start": 49406, "end": 49407, "pad": 0}, layer_norm_hidden_state=False)
@ -91,7 +91,7 @@ class StableCascadeTokenizer(sd1_clip.SD1Tokenizer):
class StableCascadeClipG(sd1_clip.SDClipModel): class StableCascadeClipG(sd1_clip.SDClipModel):
def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None): def __init__(self, device="cpu", max_length=77, freeze=True, layer="hidden", layer_idx=-1, dtype=None, textmodel_json_config=None):
textmodel_json_config = get_clip_config_dict(textmodel_json_config, "clip_config_bigg.json") textmodel_json_config = get_path_as_dict(textmodel_json_config, "clip_config_bigg.json")
super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype, super().__init__(device=device, freeze=freeze, layer=layer, layer_idx=layer_idx, textmodel_json_config=textmodel_json_config, dtype=dtype,
special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True) special_tokens={"start": 49406, "end": 49407, "pad": 49407}, layer_norm_hidden_state=False, enable_attention_masks=True)

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import contextlib import contextlib
import logging import logging
import math import math
import os
import os.path import os.path
import random import random
import struct import struct

View File

@ -1,6 +1,6 @@
from __future__ import annotations from __future__ import annotations
from typing import Any, List, Dict from typing import Any, Dict, Optional
import torch import torch
from fastchat.model import get_conversation_template from fastchat.model import get_conversation_template
@ -31,16 +31,20 @@ class TransformersLoader(CustomNode):
def INPUT_TYPES(cls) -> InputTypes: def INPUT_TYPES(cls) -> InputTypes:
return { return {
"required": { "required": {
"ckpt_name": (huggingface_repos(),) "ckpt_name": (huggingface_repos(),),
"subfolder": ("STRING", {})
} }
} }
RETURN_TYPES = "MODEL", RETURN_TYPES = "MODEL",
FUNCTION = "execute" FUNCTION = "execute"
def execute(self, ckpt_name: str): def execute(self, ckpt_name: str, subfolder: Optional[str] = None):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
with comfy_tqdm(): with comfy_tqdm():
model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True) model = AutoModelForCausalLM.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
tokenizer = AutoTokenizer.from_pretrained(ckpt_name) tokenizer = AutoTokenizer.from_pretrained(ckpt_name)
model_managed = TransformersManagedModel(ckpt_name, model, tokenizer) model_managed = TransformersManagedModel(ckpt_name, model, tokenizer)
return model_managed, return model_managed,

View File

@ -215,7 +215,7 @@ class StringPosixPathJoin(CustomNode):
def INPUT_TYPES(cls) -> InputTypes: def INPUT_TYPES(cls) -> InputTypes:
return { return {
"required": { "required": {
f"value{i}": ("STRING", {"default": "", "multiline": False}) for i in range(5) f"value{i}": ("STRING", {"default": "", "multiline": False, "forceInput": True}) for i in range(5)
} }
} }
@ -284,7 +284,7 @@ class DevNullUris(CustomNode):
class StringJoin(CustomNode): class StringJoin(CustomNode):
@classmethod @classmethod
def INPUT_TYPES(cls) -> InputTypes: def INPUT_TYPES(cls) -> InputTypes:
required = {f"value{i}": ("STRING", {"default": "", "multiline": True}) for i in range(5)} required = {f"value{i}": ("STRING", {"default": "", "multiline": True, "forceInput": True}) for i in range(5)}
required["separator"] = ("STRING", {"default": "_"}) required["separator"] = ("STRING", {"default": "_"})
return { return {
"required": required "required": required
@ -304,7 +304,7 @@ class StringToUri(CustomNode):
def INPUT_TYPES(cls) -> InputTypes: def INPUT_TYPES(cls) -> InputTypes:
return { return {
"required": { "required": {
"value": ("STRING", {"default": "", "multiline": True}), "value": ("STRING", {"default": "", "multiline": True, "forceInput": True}),
"batch": ("INT", {"default": 1}) "batch": ("INT", {"default": 1})
} }
} }