From 5459cfa832a1f9f3ede9c7475becf77bc940bac2 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Wed, 17 Jul 2024 17:06:49 -0700 Subject: [PATCH] Improve model downloading, add FolderPaths object for custom nodes --- comfy/cmd/folder_paths.py | 8 +- comfy/k_diffusion/utils.py | 7 ++ comfy/model_downloader.py | 136 ++++++++++++++++++++++++-------- comfy/model_downloader_types.py | 30 ++++++- 4 files changed, 144 insertions(+), 37 deletions(-) diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 32c5c3c54..207bcf4f1 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -210,12 +210,13 @@ def exists_annotated_filepath(name): return os.path.exists(filepath) -def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) -> str: +def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None, extensions: Optional[set[str]] = None) -> str: """ Registers a model path for the given canonical name. :param folder_name: the folder name :param full_folder_path: When none, defaults to os.path.join(models_dir, folder_name) aka the folder as a subpath to the default models directory + :param extensions: supported file extensions :return: the folder path """ global folder_names_and_paths @@ -226,6 +227,9 @@ def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None) - if full_folder_path not in folder_path.paths: folder_path.paths.append(full_folder_path) + if extensions is not None: + folder_path.supported_extensions |= extensions + invalidate_cache(folder_name) return full_folder_path @@ -270,7 +274,7 @@ def filter_files_extensions(files, extensions): return sorted(list(filter(lambda a: os.path.splitext(a)[-1].lower() in extensions or len(extensions) == 0, files))) -def get_full_path(folder_name, filename): +def get_full_path(folder_name, filename) -> Optional[str | bytes | os.PathLike]: """ Gets the path to a filename inside a folder. diff --git a/comfy/k_diffusion/utils.py b/comfy/k_diffusion/utils.py index 7d2f0f787..7190f18ec 100644 --- a/comfy/k_diffusion/utils.py +++ b/comfy/k_diffusion/utils.py @@ -1,3 +1,6 @@ +from typing import Final + + def append_dims(x, target_dims): """Appends dimensions to the end of a tensor until it has target_dims dimensions.""" dims_to_append = target_dims - x.ndim @@ -7,3 +10,7 @@ def append_dims(x, target_dims): # MPS will get inf values if it tries to index into the new axes, but detaching fixes this. # https://github.com/pytorch/pytorch/issues/84364 return expanded.detach().clone() if expanded.device.type == 'mps' else expanded + + +class FolderOfImages: + IMG_EXTENSIONS: Final[set[str]] = frozenset({'.jpg', '.jpeg', '.png', '.ppm', '.bmp', '.pgm', '.tif', '.tiff', '.webp'}) diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index eaad659be..69f2972b5 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -1,5 +1,6 @@ from __future__ import annotations +import collections import logging import operator import os @@ -7,7 +8,7 @@ from functools import reduce from itertools import chain from os.path import join from pathlib import Path -from typing import List, Any, Optional, Sequence, Final, Set +from typing import List, Any, Optional, Sequence, Final, Set, MutableSequence import tqdm from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem @@ -18,22 +19,29 @@ from safetensors.torch import save_file from .cli_args import args from .cmd import folder_paths +from .cmd.folder_paths import add_model_folder_path, supported_pt_extensions from .component_model.deprecation import _deprecate_method from .interruption import InterruptProcessingException -from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable +from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_, Downloadable, UrlFile from .utils import ProgressBar, comfy_tqdm _session = Session() _hf_fs = HfFileSystem() -def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]: +def get_filename_list_with_downloadable(folder_name: str, known_files: Optional[List[Any]] = None) -> List[str]: + if known_files is None: + known_files = _get_known_models_for_folder_name(folder_name) + existing = frozenset(folder_paths.get_filename_list(folder_name)) downloadable = frozenset() if args.disable_known_models else frozenset(str(f) for f in known_files) return sorted(list(existing | downloadable)) -def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> Optional[str]: +def get_or_download(folder_name: str, filename: str, known_files: Optional[List[Downloadable]] = None) -> Optional[str]: + if known_files is None: + known_files = _get_known_models_for_folder_name(folder_name) + path = folder_paths.get_full_path(folder_name, filename) if path is None and not args.disable_known_models: @@ -120,6 +128,8 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi if civit_file['name'] == filename: url = civit_file['downloadUrl'] break + elif isinstance(known_file, UrlFile): + url = known_file.url else: raise RuntimeError("unknown file type") @@ -159,7 +169,22 @@ Visit the repository, accept the terms, and then do one of the following: return path -KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [ +class KnownDownloadables(collections.UserList[Downloadable]): + def __init__(self, data, folder_name: Optional[str] = None): + # this should be a view + self.data = data + self._folder_name = folder_name + + @property + def folder_name(self) -> str: + return self._folder_name + + @folder_name.setter + def folder_name(self, value: str): + self._folder_name = value + + +KNOWN_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"), HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"), HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"), @@ -194,40 +219,40 @@ KNOWN_CHECKPOINTS: Final[List[Downloadable]] = [ HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", filename="sd3_medium_incl_clips_t5xxlfp8.safetensors"), HuggingFile("fal/AuraFlow", filename="aura_flow_0.1.safetensors"), -] +], folder_name="checkpoints") -KNOWN_UNCLIP_CHECKPOINTS: Final[List[Downloadable]] = [ +KNOWN_UNCLIP_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"), HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"), HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"), -] +], folder_name="checkpoints") -KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[List[Downloadable]] = [ +KNOWN_IMAGE_ONLY_CHECKPOINTS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt") -] +], folder_name="checkpoints") -KNOWN_UPSCALERS: Final[List[Downloadable]] = [ +KNOWN_UPSCALERS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth") -] +], folder_name="upscale_models") -KNOWN_GLIGEN_MODELS: Final[List[Downloadable]] = [ +KNOWN_GLIGEN_MODELS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors", show_in_ui=False), HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"), -] +], folder_name="gligen") -KNOWN_CLIP_VISION_MODELS: Final[List[Downloadable]] = [ +KNOWN_CLIP_VISION_MODELS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors") -] +], folder_name="clip_vision") -KNOWN_LORAS: Final[List[Downloadable]] = [ +KNOWN_LORAS: Final[KnownDownloadables] = KnownDownloadables([ CivitFile(model_id=211577, model_version_id=238349, filename="openxl_handsfix.safetensors"), CivitFile(model_id=324815, model_version_id=364137, filename="blur_control_xl_v1.safetensors"), CivitFile(model_id=47085, model_version_id=55199, filename="GoodHands-beta2.safetensors"), HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-12steps-CFG-lora.safetensors"), HuggingFile("ByteDance/Hyper-SD", "Hyper-SD15-12steps-CFG-lora.safetensors"), -] +], folder_name="loras") -KNOWN_CONTROLNETS: Final[List[Downloadable]] = [ +KNOWN_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104), HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"), HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"), @@ -317,9 +342,9 @@ KNOWN_CONTROLNETS: Final[List[Downloadable]] = [ HuggingFile("TheMistoAI/MistoLine", "mistoLine_rank256.safetensors"), HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model_promax.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0-promax.safetensors"), HuggingFile("xinsir/controlnet-union-sdxl-1.0", "diffusion_pytorch_model.safetensors", save_with_filename="xinsir-controlnet-union-sdxl-1.0.safetensors"), -] +], folder_name="controlnet") -KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [ +KNOWN_DIFF_CONTROLNETS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_canny_fp16.safetensors"), HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_depth_fp16.safetensors"), HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_hed_fp16.safetensors"), @@ -328,39 +353,82 @@ KNOWN_DIFF_CONTROLNETS: Final[List[Downloadable]] = [ HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_openpose_fp16.safetensors"), HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_scribble_fp16.safetensors"), HuggingFile("kohya-ss/ControlNet-diff-modules", "diff_control_sd15_seg_fp16.safetensors"), -] +], folder_name="controlnet") -KNOWN_APPROX_VAES: Final[List[Downloadable]] = [ +KNOWN_APPROX_VAES: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("madebyollin/taesd", "taesd_decoder.safetensors"), HuggingFile("madebyollin/taesdxl", "taesdxl_decoder.safetensors"), -] +], folder_name="vae_approx") -KNOWN_VAES: Final[List[Downloadable]] = [ +KNOWN_VAES: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("stabilityai/sdxl-vae", "sdxl_vae.safetensors"), HuggingFile("stabilityai/sd-vae-ft-mse-original", "vae-ft-mse-840000-ema-pruned.safetensors"), -] +], folder_name="vae") KNOWN_HUGGINGFACE_MODEL_REPOS: Final[Set[str]] = { - "JingyeChen22/textdiffuser2_layout_planner", + 'JingyeChen22/textdiffuser2_layout_planner', 'JingyeChen22/textdiffuser2-full-ft', - "microsoft/Phi-3-mini-4k-instruct", - "llava-hf/llava-v1.6-mistral-7b-hf" + 'microsoft/Phi-3-mini-4k-instruct', + 'llava-hf/llava-v1.6-mistral-7b-hf' } -KNOWN_UNET_MODELS: Final[List[Downloadable]] = [ +KNOWN_UNET_MODELS: Final[KnownDownloadables] = KnownDownloadables([ HuggingFile("ByteDance/Hyper-SD", "Hyper-SDXL-1step-Unet-Comfyui.fp16.safetensors") -] +], folder_name="unet") -KNOWN_CLIP_MODELS: Final[List[Downloadable]] = [ +KNOWN_CLIP_MODELS: Final[KnownDownloadables] = KnownDownloadables([ # todo: is this correct? HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp16.safetensors", save_with_filename="t5xxl_fp16.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/t5xxl_fp8_e4m3fn.safetensors", save_with_filename="t5xxl_fp8_e4m3fn.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_g.safetensors", save_with_filename="clip_g.safetensors"), HuggingFile("stabilityai/stable-diffusion-3-medium", "text_encoders/clip_l.safetensors", save_with_filename="clip_l.safetensors"), +], folder_name="clip") + +_known_models_db: list[KnownDownloadables] = [ + KNOWN_CHECKPOINTS, + KNOWN_VAES, + KNOWN_LORAS, + KNOWN_UNET_MODELS, + KNOWN_APPROX_VAES, + KNOWN_DIFF_CONTROLNETS, + KNOWN_CLIP_MODELS, + KNOWN_CLIP_VISION_MODELS, + KNOWN_CONTROLNETS, + KNOWN_GLIGEN_MODELS, + KNOWN_IMAGE_ONLY_CHECKPOINTS, + KNOWN_UNCLIP_CHECKPOINTS, + KNOWN_UPSCALERS, ] -def add_known_models(folder_name: str, known_models: List[Downloadable], *models: Downloadable) -> List[Downloadable]: +def _is_known_model_in_models_db(obj: list[Downloadable] | KnownDownloadables): + return any(candidate is obj or candidate.data is obj for candidate in _known_models_db) + + +def _get_known_models_for_folder_name(folder_name: str) -> List[Downloadable]: + return list(chain.from_iterable([candidate for candidate in _known_models_db if candidate.folder_name == folder_name])) + + +def add_known_models(folder_name: str, known_models: Optional[List[Downloadable]] | Downloadable = None, *models: Downloadable) -> MutableSequence[Downloadable]: + if isinstance(known_models, Downloadable): + models = [known_models] + list(models) or [] + known_models = None + + if known_models is None: + try: + known_models = next(candidate for candidate in _known_models_db if candidate.folder_name == folder_name) + except StopIteration: + add_model_folder_path(folder_name, extensions=supported_pt_extensions) + known_models = KnownDownloadables([], folder_name=folder_name) + + # check if any of the pre-existing known models already reference this list + if not _is_known_model_in_models_db(known_models): + if not isinstance(known_models, KnownDownloadables): + # wrap it + known_models = KnownDownloadables(known_models) + # meets protocol at this point + _known_models_db.append(known_models) + if len(models) < 1: return known_models @@ -368,7 +436,7 @@ def add_known_models(folder_name: str, known_models: List[Downloadable], *models logging.warning(f"Known models have been disabled in the options (while adding {folder_name}/{','.join(map(str, models))})") pre_existing = frozenset(known_models) - known_models += [model for model in models if model not in pre_existing] + known_models.extend([model for model in models if model not in pre_existing]) folder_paths.invalidate_cache(folder_name) return known_models diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index 2674bb807..158ee7c6e 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -1,12 +1,40 @@ from __future__ import annotations import dataclasses +import functools from os.path import split +from pathlib import PurePosixPath from typing import Optional, List, Sequence, Union +from can_ada import parse, URL from typing_extensions import TypedDict, NotRequired +@dataclasses.dataclass(frozen=True) +class UrlFile: + _url: str + _save_with_filename: Optional[str] = None + + def __str__(self): + return self.save_with_filename + + @functools.cached_property + def url(self) -> str: + return self._url + + @functools.cached_property + def parsed_url(self) -> URL: + return parse(self._url) + + @property + def save_with_filename(self) -> str: + return self._save_with_filename or self.filename + + @property + def filename(self) -> str: + return PurePosixPath(self.parsed_url.pathname).name + + @dataclasses.dataclass(frozen=True) class CivitFile: """ @@ -154,4 +182,4 @@ class CivitModelsGetResponse(TypedDict): modelVersions: List[CivitModelVersion] -Downloadable = Union[CivitFile | HuggingFile] +Downloadable = Union[CivitFile | HuggingFile | UrlFile]