Improve model downloading

This commit is contained in:
doctorpangloss 2024-03-12 15:25:23 -07:00
parent 73cbb5bfad
commit e68f8885e3
4 changed files with 200 additions and 35 deletions

View File

@ -2,6 +2,8 @@ import os
import sys
import time
import logging
from typing import Optional
from pkg_resources import resource_filename
from ..cli_args import args
@ -131,9 +133,11 @@ def exists_annotated_filepath(name):
return os.path.exists(filepath)
def add_model_folder_path(folder_name, full_folder_path):
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None):
global folder_names_and_paths
if folder_name in folder_names_and_paths:
if full_folder_path is None:
full_folder_path = os.path.join(models_dir, folder_name)
if folder_name in folder_names_and_paths and full_folder_path not in folder_names_and_paths[folder_name][0]:
folder_names_and_paths[folder_name][0].append(full_folder_path)
else:
folder_names_and_paths[folder_name] = ([full_folder_path], set())

View File

@ -1,53 +1,71 @@
import dataclasses
from typing import List, Optional
from __future__ import annotations
import logging
from os.path import join
from typing import List, Any, Optional
from huggingface_hub import hf_hub_download
from requests import Session
from .cmd import folder_paths
from .utils import comfy_tqdm
from posixpath import split
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse
from .utils import comfy_tqdm, ProgressBar
session = Session()
@dataclasses.dataclass
class HuggingFile:
"""
A file on Huggingface Hub
Attributes:
repo_id (str): The Huggingface repository of a known file
filename (str): The path to the known file in the repository
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
"""
repo_id: str
filename: str
show_in_ui: Optional[bool] = True
def __str__(self):
return split(self.filename)[-1]
def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]:
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
existing = frozenset(folder_paths.get_filename_list(folder_name))
downloadable = frozenset(str(f) for f in known_huggingface_files)
downloadable = frozenset(str(f) for f in known_files)
return sorted(list(existing | downloadable))
def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str:
def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> str:
path = folder_paths.get_full_path(folder_name, filename)
if path is None:
try:
destination = folder_paths.get_folder_paths(folder_name)[0]
hugging_file = next(f for f in known_huggingface_files if str(f) == filename)
known_file = next(f for f in known_files if str(f) == filename)
with comfy_tqdm():
path = hf_hub_download(repo_id=hugging_file.repo_id,
filename=hugging_file.filename,
local_dir=destination,
resume_download=True)
if isinstance(known_file, HuggingFile):
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
local_dir=destination,
resume_download=True)
else:
url: Optional[str] = None
if isinstance(known_file, CivitFile):
model_info_res = session.get(
f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}")
model_info: CivitModelsGetResponse = model_info_res.json()
for model_version in model_info['modelVersions']:
for file in model_version['files']:
if file['name'] == filename:
url = file['downloadUrl']
break
if url is not None:
break
else:
raise RuntimeError("unknown file type")
if url is None:
logging.warning(f"Could not retrieve file {str(known_file)}")
else:
with session.get(url, stream=True, allow_redirects=True) as response:
total_size = int(response.headers.get("content-length", 0))
progress_bar = ProgressBar(total=total_size)
with open(join(destination, filename), "wb") as file:
for chunk in response.iter_content(chunk_size=512 * 1024):
progress_bar.update(len(chunk))
file.write(chunk)
path = folder_paths.get_full_path(folder_name, filename)
assert path is not None
except StopIteration:
pass
except Exception:
pass
except Exception as exc:
logging.error("Error while trying to download a file", exc_info=exc)
return path
@ -69,6 +87,9 @@ KNOWN_CHECKPOINTS = [
HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False),
# from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md
HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False),
# from https://github.com/huchenlei/ComfyUI-layerdiffuse
CivitFile(133005, 357609, filename="juggernautXL_v8Rundiffusion.safetensors"),
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
]
KNOWN_UNCLIP_CHECKPOINTS = [

View File

@ -0,0 +1,139 @@
from __future__ import annotations
import dataclasses
from os.path import split
from typing import Optional
@dataclasses.dataclass
class CivitFile:
"""
A file on CivitAI
Attributes:
model_id (int): The ID of the model
model_version_id (int): The version
filename (str): The name of the file in the model
"""
model_id: int
model_version_id: int
filename: str
def __str__(self):
return self.filename
@dataclasses.dataclass
class HuggingFile:
"""
A file on Huggingface Hub
Attributes:
repo_id (str): The Huggingface repository of a known file
filename (str): The path to the known file in the repository
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
"""
repo_id: str
filename: str
show_in_ui: Optional[bool] = True
def __str__(self):
return split(self.filename)[-1]
from typing import TypedDict, List, Optional, NotRequired
class CivitStats(TypedDict):
downloadCount: int
favoriteCount: NotRequired[int]
thumbsUpCount: int
thumbsDownCount: int
commentCount: int
ratingCount: int
rating: float
tippedAmountCount: NotRequired[int]
class CivitCreator(TypedDict):
username: str
image: str
class CivitFileMetadata(TypedDict, total=False):
fp: Optional[str]
size: Optional[str]
format: Optional[str]
class CivitFile_(TypedDict):
id: int
sizeKB: float
name: str
type: str
metadata: CivitFileMetadata
pickleScanResult: str
pickleScanMessage: Optional[str]
virusScanResult: str
virusScanMessage: Optional[str]
scannedAt: str
hashes: dict
downloadUrl: str
primary: bool
class CivitImageMetadata(TypedDict):
hash: str
size: int
width: int
height: int
class CivitImage(TypedDict):
url: str
nsfw: str
width: int
height: int
hash: str
type: str
metadata: CivitImageMetadata
availability: str
class CivitModelVersion(TypedDict):
id: int
modelId: int
name: str
createdAt: str
updatedAt: str
status: str
publishedAt: str
trainedWords: List[str]
trainingStatus: NotRequired[Optional[str]]
trainingDetails: NotRequired[Optional[str]]
baseModel: str
baseModelType: str
earlyAccessTimeFrame: int
description: str
vaeId: NotRequired[Optional[int]]
stats: CivitStats
files: List[CivitFile_]
images: List[CivitImage]
downloadUrl: str
class CivitModelsGetResponse(TypedDict):
id: int
name: str
description: str
type: str
poi: bool
nsfw: bool
allowNoCredit: bool
allowCommercialUse: List[str]
allowDerivatives: bool
allowDifferentLicense: bool
stats: CivitStats
creator: CivitCreator
tags: List[str]
modelVersions: List[CivitModelVersion]

View File

@ -23,8 +23,9 @@ from .. import model_management
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS
from ..model_downloader_types import HuggingFile
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet