Improve model downloading, add FolderPaths object for custom nodes

This commit is contained in:
doctorpangloss 2024-07-17 17:06:49 -07:00
parent d98c2c5456
commit 5459cfa832
4 changed files with 144 additions and 37 deletions

View File

@ -210,12 +210,13 @@ def exists_annotated_filepath(name):
return os.path.exists(filepath)
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -> str:
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None) -> str:
"""
Registers a model path for the given canonical name.
:param folder_name: the folder name
:param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as
a subpath to the default models directory
:param extensions: supported file extensions
:return: the folder path
"""
global folder_names_and_paths
@ -226,6 +227,9 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -
if full_folder_path not in folder_path.paths:
folder_path.paths.append(full_folder_path)
if extensions is not None:
folder_path.supported_extensions |= extensions
invalidate_cache(folder_name)
return full_folder_path
@ -270,7 +274,7 @@ def filter_files_extensions(files, extensions):
return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files)))
def get_full_path(folder_name, filename):
def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]:
"""
Gets the path to a filename inside a folder.

View File

@ -1,3 +1,6 @@
from typing import Final
def append_dims(x, target_dims):
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
dims_to_append = target_dims - x.ndim
@ -7,3 +10,7 @@ def append_dims(x, target_dims):
# MPS will get inf values if it tries to index into the new axes, but detaching fixes this.
# https://github.com/pytorch/pytorch/issues/84364
return expanded.detach().clone() if expanded.device.type == 'mps' else expanded
class FolderOfImages:
IMG_EXTENSIONS: Final[set[str]] = frozenset({'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'})

View File

@ -1,5 +1,6 @@
from __future__ import annotations
import collections
import logging
import operator
import os
@ -7,7 +8,7 @@ from functools import reduce
from itertools import chain
from os.path import join
from pathlib import Path
from typing import List, Any, Optional, Sequence, Final, Set
from typing import List, Any, Optional, Sequence, Final, Set, MutableSequence
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
@ -18,22 +19,29 @@ from safetensors.torch import save_file
from .cli_args import args
from .cmd import folder_paths
from .cmd.folder_paths import add_model_folder_path, supported_pt_extensions
from .component_model.deprecation import _deprecate_method
from .interruption import InterruptProcessingException
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable, UrlFile
from .utils import ProgressBar, comfy_tqdm
_session = Session()
_hf_fs = HfFileSystem()
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]:
if known_files is None:
known_files = _get_known_models_for_folder_name(folder_name)
existing = frozenset(folder_paths.get_filename_list(folder_name))
downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files)
return sorted(list(existing | downloadable))
def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> Optional[str]:
def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable]] = None) -> Optional[str]:
if known_files is None:
known_files = _get_known_models_for_folder_name(folder_name)
path = folder_paths.get_full_path(folder_name, filename)
if path is None and not args.disable_known_models:
@ -120,6 +128,8 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
if civit_file['name'] == filename:
url = civit_file['downloadUrl']
break
elif isinstance(known_file, UrlFile):
url = known_file.url
else:
raise RuntimeError("unknown file type")
@ -159,7 +169,22 @@ Visit the repository, accept the terms, and then do one of the following:
return path
KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [
class KnownDownloadables(collections.UserList[Downloadable]):
def __init__(self, data, folder_name: Optional[str] = None):
# this should be a view
self.data = data
self._folder_name = folder_name
@property
def folder_name(self) -> str:
return self._folder_name
@folder_name.setter
def folder_name(self, value: str):
self._folder_name = value
KNOWN_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
@ -194,40 +219,40 @@ KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"),
HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"),
]
], folder_name="checkpoints")
KNOWN_UNCLIP_CHECKPOINTS: Final[List[Downloadable]] = [
KNOWN_UNCLIP_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
]
], folder_name="checkpoints")
KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[List[Downloadable]] = [
KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
]
], folder_name="checkpoints")
KNOWN_UPSCALERS: Final[List[Downloadable]] = [
KNOWN_UPSCALERS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
]
], folder_name="upscale_models")
KNOWN_GLIGEN_MODELS: Final[List[Downloadable]] = [
KNOWN_GLIGEN_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors", show_in_ui=False),
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
]
], folder_name="gligen")
KNOWN_CLIP_VISION_MODELS: Final[List[Downloadable]] = [
KNOWN_CLIP_VISION_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
]
], folder_name="clip_vision")
KNOWN_LORAS: Final[List[Downloadable]] = [
KNOWN_LORAS: Final[KnownDownloadables] = KnownDownloadables([
CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"),
CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"),
CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"),
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-12steps-CFG-lora.safetensors"),
HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"),
]
], folder_name="loras")
KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104),
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"),
HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"),
@ -317,9 +342,9 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [
HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"),
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"),
HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"),
]
], folder_name="controlnet")
KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_canny_fp16.safetensors"),
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_depth_fp16.safetensors"),
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_hed_fp16.safetensors"),
@ -328,39 +353,82 @@ KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_openpose_fp16.safetensors"),
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_scribble_fp16.safetensors"),
HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"),
]
], folder_name="controlnet")
KNOWN_APPROX_VAES: Final[List[Downloadable]] = [
KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"),
HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"),
]
], folder_name="vae_approx")
KNOWN_VAES: Final[List[Downloadable]] = [
KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("stabilityai/sdxl-vae", "sdxl_vae.safetensors"),
HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"),
]
], folder_name="vae")
KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = {
"JingyeChen22/textdiffuser2_layout_planner",
'JingyeChen22/textdiffuser2_layout_planner',
'JingyeChen22/textdiffuser2-full-ft',
"microsoft/Phi-3-mini-4k-instruct",
"llava-hf/llava-v1.6-mistral-7b-hf"
'microsoft/Phi-3-mini-4k-instruct',
'llava-hf/llava-v1.6-mistral-7b-hf'
}
KNOWN_UNET_MODELS: Final[List[Downloadable]] = [
KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([
HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors")
]
], folder_name="unet")
KNOWN_CLIP_MODELS: Final[List[Downloadable]] = [
KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([
# todo: is this correct?
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"),
HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"),
], folder_name="clip")
_known_models_db: list[KnownDownloadables] = [
KNOWN_CHECKPOINTS,
KNOWN_VAES,
KNOWN_LORAS,
KNOWN_UNET_MODELS,
KNOWN_APPROX_VAES,
KNOWN_DIFF_CONTROLNETS,
KNOWN_CLIP_MODELS,
KNOWN_CLIP_VISION_MODELS,
KNOWN_CONTROLNETS,
KNOWN_GLIGEN_MODELS,
KNOWN_IMAGE_ONLY_CHECKPOINTS,
KNOWN_UNCLIP_CHECKPOINTS,
KNOWN_UPSCALERS,
]
def add_known_models(folder_name: str, known_models: List[Downloadable], *models: Downloadable) -> List[Downloadable]:
def _is_known_model_in_models_db(obj: list[Downloadable] | KnownDownloadables):
return any(candidate is obj or candidate.data is obj for candidate in _known_models_db)
def _get_known_models_for_folder_name(folder_name: str) -> List[Downloadable]:
return list(chain.from_iterable([candidate for candidate in _known_models_db if candidate.folder_name == folder_name]))
def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]] | Downloadable = None, *models: Downloadable) -> MutableSequence[Downloadable]:
if isinstance(known_models, Downloadable):
models = [known_models] + list(models) or []
known_models = None
if known_models is None:
try:
known_models = next(candidate for candidate in _known_models_db if candidate.folder_name == folder_name)
except StopIteration:
add_model_folder_path(folder_name, extensions=supported_pt_extensions)
known_models = KnownDownloadables([], folder_name=folder_name)
# check if any of the pre-existing known models already reference this list
if not _is_known_model_in_models_db(known_models):
if not isinstance(known_models, KnownDownloadables):
# wrap it
known_models = KnownDownloadables(known_models)
# meets protocol at this point
_known_models_db.append(known_models)
if len(models) < 1:
return known_models
@ -368,7 +436,7 @@ def add_known_models(folder_name: str, known_models: List[Downloadable], *models
logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})")
pre_existing = frozenset(known_models)
known_models += [model for model in models if model not in pre_existing]
known_models.extend([model for model in models if model not in pre_existing])
folder_paths.invalidate_cache(folder_name)
return known_models

View File

@ -1,12 +1,40 @@
from __future__ import annotations
import dataclasses
import functools
from os.path import split
from pathlib import PurePosixPath
from typing import Optional, List, Sequence, Union
from can_ada import parse, URL
from typing_extensions import TypedDict, NotRequired
@dataclasses.dataclass(frozen=True)
class UrlFile:
_url: str
_save_with_filename: Optional[str] = None
def __str__(self):
return self.save_with_filename
@functools.cached_property
def url(self) -> str:
return self._url
@functools.cached_property
def parsed_url(self) -> URL:
return parse(self._url)
@property
def save_with_filename(self) -> str:
return self._save_with_filename or self.filename
@property
def filename(self) -> str:
return PurePosixPath(self.parsed_url.pathname).name
@dataclasses.dataclass(frozen=True)
class CivitFile:
"""
@ -154,4 +182,4 @@ class CivitModelsGetResponse(TypedDict):
modelVersions: List[CivitModelVersion]
Downloadable = Union[CivitFile | HuggingFile]
Downloadable = Union[CivitFile | HuggingFile | UrlFile]