mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Improve model downloading, add FolderPaths object for custom nodes
This commit is contained in:
parent
d98c2c5456
commit
5459cfa832
@ -210,12 +210,13 @@ def exists_annotated_filepath(name):
|
|||||||
return os.path.exists(filepath)
|
return os.path.exists(filepath)
|
||||||
|
|
||||||
|
|
||||||
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -> str:
|
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None) -> str:
|
||||||
"""
|
"""
|
||||||
Registers a model path for the given canonical name.
|
Registers a model path for the given canonical name.
|
||||||
:param folder_name: the folder name
|
:param folder_name: the folder name
|
||||||
:param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as
|
:param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as
|
||||||
a subpath to the default models directory
|
a subpath to the default models directory
|
||||||
|
:param extensions: supported file extensions
|
||||||
:return: the folder path
|
:return: the folder path
|
||||||
"""
|
"""
|
||||||
global folder_names_and_paths
|
global folder_names_and_paths
|
||||||
@ -226,6 +227,9 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -
|
|||||||
if full_folder_path not in folder_path.paths:
|
if full_folder_path not in folder_path.paths:
|
||||||
folder_path.paths.append(full_folder_path)
|
folder_path.paths.append(full_folder_path)
|
||||||
|
|
||||||
|
if extensions is not None:
|
||||||
|
folder_path.supported_extensions |= extensions
|
||||||
|
|
||||||
invalidate_cache(folder_name)
|
invalidate_cache(folder_name)
|
||||||
return full_folder_path
|
return full_folder_path
|
||||||
|
|
||||||
@ -270,7 +274,7 @@ def filter_files_extensions(files, extensions):
|
|||||||
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
|
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
|
||||||
|
|
||||||
|
|
||||||
def get_full_path(folder_name, filename):
|
def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
|
||||||
"""
|
"""
|
||||||
Gets the path to a filename inside a folder.
|
Gets the path to a filename inside a folder.
|
||||||
|
|
||||||
|
|||||||
@ -1,3 +1,6 @@
|
|||||||
|
from typing import Final
|
||||||
|
|
||||||
|
|
||||||
def append_dims(x, target_dims):
|
def append_dims(x, target_dims):
|
||||||
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
||||||
dims_to_append = target_dims - x.ndim
|
dims_to_append = target_dims - x.ndim
|
||||||
@ -7,3 +10,7 @@ def append_dims(x, target_dims):
|
|||||||
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
|
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
|
||||||
# https://github.com/pytorch/pytorch/issues/84364
|
# https://github.com/pytorch/pytorch/issues/84364
|
||||||
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
|
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
|
||||||
|
|
||||||
|
|
||||||
|
class FolderOfImages:
|
||||||
|
IMG_EXTENSIONS: Final[set[str]] = frozenset({'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'})
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import collections
|
||||||
import logging
|
import logging
|
||||||
import operator
|
import operator
|
||||||
import os
|
import os
|
||||||
@ -7,7 +8,7 @@ from functools import reduce
|
|||||||
from itertools import chain
|
from itertools import chain
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List, Any, Optional, Sequence, Final, Set
|
from typing import List, Any, Optional, Sequence, Final, Set, MutableSequence
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
|
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
|
||||||
@ -18,22 +19,29 @@ from safetensors.torch import save_file
|
|||||||
|
|
||||||
from .cli_args import args
|
from .cli_args import args
|
||||||
from .cmd import folder_paths
|
from .cmd import folder_paths
|
||||||
|
from .cmd.folder_paths import add_model_folder_path, supported_pt_extensions
|
||||||
from .component_model.deprecation import _deprecate_method
|
from .component_model.deprecation import _deprecate_method
|
||||||
from .interruption import InterruptProcessingException
|
from .interruption import InterruptProcessingException
|
||||||
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable
|
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable, UrlFile
|
||||||
from .utils import ProgressBar, comfy_tqdm
|
from .utils import ProgressBar, comfy_tqdm
|
||||||
|
|
||||||
_session = Session()
|
_session = Session()
|
||||||
_hf_fs = HfFileSystem()
|
_hf_fs = HfFileSystem()
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
|
def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
|
||||||
|
if known_files is None:
|
||||||
|
known_files = _get_known_models_for_folder_name(folder_name)
|
||||||
|
|
||||||
existing = frozenset(folder_paths.get_filename_list(folder_name))
|
existing = frozenset(folder_paths.get_filename_list(folder_name))
|
||||||
downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files)
|
downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files)
|
||||||
return sorted(list(existing | downloadable))
|
return sorted(list(existing | downloadable))
|
||||||
|
|
||||||
|
|
||||||
def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> Optional[str]:
|
def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable]] = None) -> Optional[str]:
|
||||||
|
if known_files is None:
|
||||||
|
known_files = _get_known_models_for_folder_name(folder_name)
|
||||||
|
|
||||||
path = folder_paths.get_full_path(folder_name, filename)
|
path = folder_paths.get_full_path(folder_name, filename)
|
||||||
|
|
||||||
if path is None and not args.disable_known_models:
|
if path is None and not args.disable_known_models:
|
||||||
@ -120,6 +128,8 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
|||||||
if civit_file['name'] == filename:
|
if civit_file['name'] == filename:
|
||||||
url = civit_file['downloadUrl']
|
url = civit_file['downloadUrl']
|
||||||
break
|
break
|
||||||
|
elif isinstance(known_file, UrlFile):
|
||||||
|
url = known_file.url
|
||||||
else:
|
else:
|
||||||
raise RuntimeError("unknown file type")
|
raise RuntimeError("unknown file type")
|
||||||
|
|
||||||
@ -159,7 +169,22 @@ Visit the repository, accept the terms, and then do one of the following:
|
|||||||
return path
|
return path
|
||||||
|
|
||||||
|
|
||||||
KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [
|
class KnownDownloadables(collections.UserList[Downloadable]):
|
||||||
|
def __init__(self, data, folder_name: Optional[str] = None):
|
||||||
|
# this should be a view
|
||||||
|
self.data = data
|
||||||
|
self._folder_name = folder_name
|
||||||
|
|
||||||
|
@property
|
||||||
|
def folder_name(self) -> str:
|
||||||
|
return self._folder_name
|
||||||
|
|
||||||
|
@folder_name.setter
|
||||||
|
def folder_name(self, value: str):
|
||||||
|
self._folder_name = value
|
||||||
|
|
||||||
|
|
||||||
|
KNOWN_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
|
||||||
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
|
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
|
||||||
@ -194,40 +219,40 @@ KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [
|
|||||||
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
|
||||||
HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
|
HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
|
||||||
]
|
], folder_name="checkpoints")
|
||||||
|
|
||||||
KNOWN_UNCLIP_CHECKPOINTS: Final[List[Downloadable]] = [
|
KNOWN_UNCLIP_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
|
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
|
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
|
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
|
||||||
]
|
], folder_name="checkpoints")
|
||||||
|
|
||||||
KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[List[Downloadable]] = [
|
KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
|
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
|
||||||
]
|
], folder_name="checkpoints")
|
||||||
|
|
||||||
KNOWN_UPSCALERS: Final[List[Downloadable]] = [
|
KNOWN_UPSCALERS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
|
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
|
||||||
]
|
], folder_name="upscale_models")
|
||||||
|
|
||||||
KNOWN_GLIGEN_MODELS: Final[List[Downloadable]] = [
|
KNOWN_GLIGEN_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors", show_in_ui=False),
|
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors", show_in_ui=False),
|
||||||
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
|
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
|
||||||
]
|
], folder_name="gligen")
|
||||||
|
|
||||||
KNOWN_CLIP_VISION_MODELS: Final[List[Downloadable]] = [
|
KNOWN_CLIP_VISION_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
|
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
|
||||||
]
|
], folder_name="clip_vision")
|
||||||
|
|
||||||
KNOWN_LORAS: Final[List[Downloadable]] = [
|
KNOWN_LORAS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
|
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
|
||||||
CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"),
|
CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"),
|
||||||
CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"),
|
CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"),
|
||||||
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-12steps-CFG-lora.safetensors"),
|
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-12steps-CFG-lora.safetensors"),
|
||||||
HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"),
|
HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"),
|
||||||
]
|
], folder_name="loras")
|
||||||
|
|
||||||
KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
|
KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104),
|
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104),
|
||||||
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"),
|
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"),
|
||||||
HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"),
|
HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"),
|
||||||
@ -317,9 +342,9 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
|
|||||||
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
|
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
|
||||||
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
|
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
|
||||||
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
|
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
|
||||||
]
|
], folder_name="controlnet")
|
||||||
|
|
||||||
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
|
KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_canny_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_canny_fp16.safetensors"),
|
||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_depth_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_depth_fp16.safetensors"),
|
||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_hed_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_hed_fp16.safetensors"),
|
||||||
@ -328,39 +353,82 @@ KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
|
|||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_openpose_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_openpose_fp16.safetensors"),
|
||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_scribble_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_scribble_fp16.safetensors"),
|
||||||
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
|
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
|
||||||
]
|
], folder_name="controlnet")
|
||||||
|
|
||||||
KNOWN_APPROX_VAES: Final[List[Downloadable]] = [
|
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
|
||||||
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
|
||||||
]
|
], folder_name="vae_approx")
|
||||||
|
|
||||||
KNOWN_VAES: Final[List[Downloadable]] = [
|
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("stabilityai/sdxl-vae", "sdxl_vae.safetensors"),
|
HuggingFile("stabilityai/sdxl-vae", "sdxl_vae.safetensors"),
|
||||||
HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
|
HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
|
||||||
]
|
], folder_name="vae")
|
||||||
|
|
||||||
KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
|
KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
|
||||||
"JingyeChen22/textdiffuser2_layout_planner",
|
'JingyeChen22/textdiffuser2_layout_planner',
|
||||||
'JingyeChen22/textdiffuser2-full-ft',
|
'JingyeChen22/textdiffuser2-full-ft',
|
||||||
"microsoft/Phi-3-mini-4k-instruct",
|
'microsoft/Phi-3-mini-4k-instruct',
|
||||||
"llava-hf/llava-v1.6-mistral-7b-hf"
|
'llava-hf/llava-v1.6-mistral-7b-hf'
|
||||||
}
|
}
|
||||||
|
|
||||||
KNOWN_UNET_MODELS: Final[List[Downloadable]] = [
|
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
|
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
|
||||||
]
|
], folder_name="unet")
|
||||||
|
|
||||||
KNOWN_CLIP_MODELS: Final[List[Downloadable]] = [
|
KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
|
||||||
# todo: is this correct?
|
# todo: is this correct?
|
||||||
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"),
|
||||||
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"),
|
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"),
|
||||||
|
], folder_name="clip")
|
||||||
|
|
||||||
|
_known_models_db: list[KnownDownloadables] = [
|
||||||
|
KNOWN_CHECKPOINTS,
|
||||||
|
KNOWN_VAES,
|
||||||
|
KNOWN_LORAS,
|
||||||
|
KNOWN_UNET_MODELS,
|
||||||
|
KNOWN_APPROX_VAES,
|
||||||
|
KNOWN_DIFF_CONTROLNETS,
|
||||||
|
KNOWN_CLIP_MODELS,
|
||||||
|
KNOWN_CLIP_VISION_MODELS,
|
||||||
|
KNOWN_CONTROLNETS,
|
||||||
|
KNOWN_GLIGEN_MODELS,
|
||||||
|
KNOWN_IMAGE_ONLY_CHECKPOINTS,
|
||||||
|
KNOWN_UNCLIP_CHECKPOINTS,
|
||||||
|
KNOWN_UPSCALERS,
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
def add_known_models(folder_name: str, known_models: List[Downloadable], *models: Downloadable) -> List[Downloadable]:
|
def _is_known_model_in_models_db(obj: list[Downloadable] | KnownDownloadables):
|
||||||
|
return any(candidate is obj or candidate.data is obj for candidate in _known_models_db)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_known_models_for_folder_name(folder_name: str) -> List[Downloadable]:
|
||||||
|
return list(chain.from_iterable([candidate for candidate in _known_models_db if candidate.folder_name == folder_name]))
|
||||||
|
|
||||||
|
|
||||||
|
def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]] | Downloadable = None, *models: Downloadable) -> MutableSequence[Downloadable]:
|
||||||
|
if isinstance(known_models, Downloadable):
|
||||||
|
models = [known_models] + list(models) or []
|
||||||
|
known_models = None
|
||||||
|
|
||||||
|
if known_models is None:
|
||||||
|
try:
|
||||||
|
known_models = next(candidate for candidate in _known_models_db if candidate.folder_name == folder_name)
|
||||||
|
except StopIteration:
|
||||||
|
add_model_folder_path(folder_name, extensions=supported_pt_extensions)
|
||||||
|
known_models = KnownDownloadables([], folder_name=folder_name)
|
||||||
|
|
||||||
|
# check if any of the pre-existing known models already reference this list
|
||||||
|
if not _is_known_model_in_models_db(known_models):
|
||||||
|
if not isinstance(known_models, KnownDownloadables):
|
||||||
|
# wrap it
|
||||||
|
known_models = KnownDownloadables(known_models)
|
||||||
|
# meets protocol at this point
|
||||||
|
_known_models_db.append(known_models)
|
||||||
|
|
||||||
if len(models) < 1:
|
if len(models) < 1:
|
||||||
return known_models
|
return known_models
|
||||||
|
|
||||||
@ -368,7 +436,7 @@ def add_known_models(folder_name: str, known_models: List[Downloadable], *models
|
|||||||
logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
|
logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
|
||||||
|
|
||||||
pre_existing = frozenset(known_models)
|
pre_existing = frozenset(known_models)
|
||||||
known_models += [model for model in models if model not in pre_existing]
|
known_models.extend([model for model in models if model not in pre_existing])
|
||||||
folder_paths.invalidate_cache(folder_name)
|
folder_paths.invalidate_cache(folder_name)
|
||||||
return known_models
|
return known_models
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,40 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import functools
|
||||||
from os.path import split
|
from os.path import split
|
||||||
|
from pathlib import PurePosixPath
|
||||||
from typing import Optional, List, Sequence, Union
|
from typing import Optional, List, Sequence, Union
|
||||||
|
|
||||||
|
from can_ada import parse, URL
|
||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass(frozen=True)
|
||||||
|
class UrlFile:
|
||||||
|
_url: str
|
||||||
|
_save_with_filename: Optional[str] = None
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return self.save_with_filename
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def url(self) -> str:
|
||||||
|
return self._url
|
||||||
|
|
||||||
|
@functools.cached_property
|
||||||
|
def parsed_url(self) -> URL:
|
||||||
|
return parse(self._url)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def save_with_filename(self) -> str:
|
||||||
|
return self._save_with_filename or self.filename
|
||||||
|
|
||||||
|
@property
|
||||||
|
def filename(self) -> str:
|
||||||
|
return PurePosixPath(self.parsed_url.pathname).name
|
||||||
|
|
||||||
|
|
||||||
@dataclasses.dataclass(frozen=True)
|
@dataclasses.dataclass(frozen=True)
|
||||||
class CivitFile:
|
class CivitFile:
|
||||||
"""
|
"""
|
||||||
@ -154,4 +182,4 @@ class CivitModelsGetResponse(TypedDict):
|
|||||||
modelVersions: List[CivitModelVersion]
|
modelVersions: List[CivitModelVersion]
|
||||||
|
|
||||||
|
|
||||||
Downloadable = Union[CivitFile | HuggingFile]
|
Downloadable = Union[CivitFile | HuggingFile | UrlFile]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user