From e68f8885e3e13186bc8e1721a272d824273b6d43 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 12 Mar 2024 15:25:23 -0700 Subject: [PATCH] Improve model downloading --- comfy/cmd/folder_paths.py | 8 +- comfy/model_downloader.py | 85 +++++++++++-------- comfy/model_downloader_types.py | 139 ++++++++++++++++++++++++++++++++ comfy/nodes/base_nodes.py | 3 +- 4 files changed, 200 insertions(+), 35 deletions(-) create mode 100644 comfy/model_downloader_types.py diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 05d0b625d..8c26b228c 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -2,6 +2,8 @@ import os import sys import time import logging +from typing import Optional + from pkg_resources import resource_filename from ..cli_args import args @@ -131,9 +133,11 @@ def exists_annotated_filepath(name): return os.path.exists(filepath) -def add_model_folder_path(folder_name, full_folder_path): +def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None): global folder_names_and_paths - if folder_name in folder_names_and_paths: + if full_folder_path is None: + full_folder_path = os.path.join(models_dir, folder_name) + if folder_name in folder_names_and_paths and full_folder_path not in folder_names_and_paths[folder_name][0]: folder_names_and_paths[folder_name][0].append(full_folder_path) else: folder_names_and_paths[folder_name] = ([full_folder_path], set()) diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 9bed8907c..330e39429 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -1,53 +1,71 @@ -import dataclasses -from typing import List, Optional +from __future__ import annotations + +import logging +from os.path import join +from typing import List, Any, Optional from huggingface_hub import hf_hub_download +from requests import Session from .cmd import folder_paths -from .utils import comfy_tqdm -from posixpath import split +from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse +from .utils import comfy_tqdm, ProgressBar + +session = Session() -@dataclasses.dataclass -class HuggingFile: - """ - A file on Huggingface Hub - - Attributes: - repo_id (str): The Huggingface repository of a known file - filename (str): The path to the known file in the repository - show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter - """ - repo_id: str - filename: str - show_in_ui: Optional[bool] = True - - def __str__(self): - return split(self.filename)[-1] - - -def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]: +def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]: existing = frozenset(folder_paths.get_filename_list(folder_name)) - downloadable = frozenset(str(f) for f in known_huggingface_files) + downloadable = frozenset(str(f) for f in known_files) return sorted(list(existing | downloadable)) -def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str: +def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> str: path = folder_paths.get_full_path(folder_name, filename) if path is None: try: destination = folder_paths.get_folder_paths(folder_name)[0] - hugging_file = next(f for f in known_huggingface_files if str(f) == filename) + known_file = next(f for f in known_files if str(f) == filename) with comfy_tqdm(): - path = hf_hub_download(repo_id=hugging_file.repo_id, - filename=hugging_file.filename, - local_dir=destination, - resume_download=True) + if isinstance(known_file, HuggingFile): + path = hf_hub_download(repo_id=known_file.repo_id, + filename=known_file.filename, + local_dir=destination, + resume_download=True) + else: + url: Optional[str] = None + + if isinstance(known_file, CivitFile): + model_info_res = session.get( + f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}") + model_info: CivitModelsGetResponse = model_info_res.json() + for model_version in model_info['modelVersions']: + for file in model_version['files']: + if file['name'] == filename: + url = file['downloadUrl'] + break + if url is not None: + break + else: + raise RuntimeError("unknown file type") + + if url is None: + logging.warning(f"Could not retrieve file {str(known_file)}") + else: + with session.get(url, stream=True, allow_redirects=True) as response: + total_size = int(response.headers.get("content-length", 0)) + progress_bar = ProgressBar(total=total_size) + with open(join(destination, filename), "wb") as file: + for chunk in response.iter_content(chunk_size=512 * 1024): + progress_bar.update(len(chunk)) + file.write(chunk) + path = folder_paths.get_full_path(folder_name, filename) + assert path is not None except StopIteration: pass - except Exception: - pass + except Exception as exc: + logging.error("Error while trying to download a file", exc_info=exc) return path @@ -69,6 +87,9 @@ KNOWN_CHECKPOINTS = [ HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False), # from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False), + # from https://github.com/huchenlei/ComfyUI-layerdiffuse + CivitFile(133005, 357609, filename="juggernautXL_v8Rundiffusion.safetensors"), + CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"), ] KNOWN_UNCLIP_CHECKPOINTS = [ diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py new file mode 100644 index 000000000..99f11a6be --- /dev/null +++ b/comfy/model_downloader_types.py @@ -0,0 +1,139 @@ +from __future__ import annotations + +import dataclasses +from os.path import split +from typing import Optional + + +@dataclasses.dataclass +class CivitFile: + """ + A file on CivitAI + + Attributes: + model_id (int): The ID of the model + model_version_id (int): The version + filename (str): The name of the file in the model + """ + model_id: int + model_version_id: int + filename: str + + def __str__(self): + return self.filename + + +@dataclasses.dataclass +class HuggingFile: + """ + A file on Huggingface Hub + + Attributes: + repo_id (str): The Huggingface repository of a known file + filename (str): The path to the known file in the repository + show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter + """ + repo_id: str + filename: str + show_in_ui: Optional[bool] = True + + def __str__(self): + return split(self.filename)[-1] + + +from typing import TypedDict, List, Optional, NotRequired + + +class CivitStats(TypedDict): + downloadCount: int + favoriteCount: NotRequired[int] + thumbsUpCount: int + thumbsDownCount: int + commentCount: int + ratingCount: int + rating: float + tippedAmountCount: NotRequired[int] + + +class CivitCreator(TypedDict): + username: str + image: str + + +class CivitFileMetadata(TypedDict, total=False): + fp: Optional[str] + size: Optional[str] + format: Optional[str] + + +class CivitFile_(TypedDict): + id: int + sizeKB: float + name: str + type: str + metadata: CivitFileMetadata + pickleScanResult: str + pickleScanMessage: Optional[str] + virusScanResult: str + virusScanMessage: Optional[str] + scannedAt: str + hashes: dict + downloadUrl: str + primary: bool + + +class CivitImageMetadata(TypedDict): + hash: str + size: int + width: int + height: int + + +class CivitImage(TypedDict): + url: str + nsfw: str + width: int + height: int + hash: str + type: str + metadata: CivitImageMetadata + availability: str + + +class CivitModelVersion(TypedDict): + id: int + modelId: int + name: str + createdAt: str + updatedAt: str + status: str + publishedAt: str + trainedWords: List[str] + trainingStatus: NotRequired[Optional[str]] + trainingDetails: NotRequired[Optional[str]] + baseModel: str + baseModelType: str + earlyAccessTimeFrame: int + description: str + vaeId: NotRequired[Optional[int]] + stats: CivitStats + files: List[CivitFile_] + images: List[CivitImage] + downloadUrl: str + + +class CivitModelsGetResponse(TypedDict): + id: int + name: str + description: str + type: str + poi: bool + nsfw: bool + allowNoCredit: bool + allowCommercialUse: List[str] + allowDerivatives: bool + allowDifferentLicense: bool + stats: CivitStats + creator: CivitCreator + tags: List[str] + modelVersions: List[CivitModelVersion] diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 550f5f8e5..a0f39f865 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -23,8 +23,9 @@ from .. import model_management from ..cli_args import args from ..cmd import folder_paths, latent_preview -from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ +from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS +from ..model_downloader_types import HuggingFile from ..nodes.common import MAX_RESOLUTION from .. import controlnet