mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
Improve model downloading
This commit is contained in:
parent
73cbb5bfad
commit
e68f8885e3
@ -2,6 +2,8 @@ import os
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional
|
||||
|
||||
from pkg_resources import resource_filename
|
||||
|
||||
from ..cli_args import args
|
||||
@ -131,9 +133,11 @@ def exists_annotated_filepath(name):
|
||||
return os.path.exists(filepath)
|
||||
|
||||
|
||||
def add_model_folder_path(folder_name, full_folder_path):
|
||||
def add_model_folder_path(folder_name, full_folder_path: Optional[str] = None):
|
||||
global folder_names_and_paths
|
||||
if folder_name in folder_names_and_paths:
|
||||
if full_folder_path is None:
|
||||
full_folder_path = os.path.join(models_dir, folder_name)
|
||||
if folder_name in folder_names_and_paths and full_folder_path not in folder_names_and_paths[folder_name][0]:
|
||||
folder_names_and_paths[folder_name][0].append(full_folder_path)
|
||||
else:
|
||||
folder_names_and_paths[folder_name] = ([full_folder_path], set())
|
||||
|
||||
@ -1,53 +1,71 @@
|
||||
import dataclasses
|
||||
from typing import List, Optional
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from os.path import join
|
||||
from typing import List, Any, Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
from requests import Session
|
||||
|
||||
from .cmd import folder_paths
|
||||
from .utils import comfy_tqdm
|
||||
from posixpath import split
|
||||
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse
|
||||
from .utils import comfy_tqdm, ProgressBar
|
||||
|
||||
session = Session()
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HuggingFile:
|
||||
"""
|
||||
A file on Huggingface Hub
|
||||
|
||||
Attributes:
|
||||
repo_id (str): The Huggingface repository of a known file
|
||||
filename (str): The path to the known file in the repository
|
||||
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
|
||||
"""
|
||||
repo_id: str
|
||||
filename: str
|
||||
show_in_ui: Optional[bool] = True
|
||||
|
||||
def __str__(self):
|
||||
return split(self.filename)[-1]
|
||||
|
||||
|
||||
def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]:
|
||||
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
|
||||
existing = frozenset(folder_paths.get_filename_list(folder_name))
|
||||
downloadable = frozenset(str(f) for f in known_huggingface_files)
|
||||
downloadable = frozenset(str(f) for f in known_files)
|
||||
return sorted(list(existing | downloadable))
|
||||
|
||||
|
||||
def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str:
|
||||
def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFile | CivitFile]) -> str:
|
||||
path = folder_paths.get_full_path(folder_name, filename)
|
||||
|
||||
if path is None:
|
||||
try:
|
||||
destination = folder_paths.get_folder_paths(folder_name)[0]
|
||||
hugging_file = next(f for f in known_huggingface_files if str(f) == filename)
|
||||
known_file = next(f for f in known_files if str(f) == filename)
|
||||
with comfy_tqdm():
|
||||
path = hf_hub_download(repo_id=hugging_file.repo_id,
|
||||
filename=hugging_file.filename,
|
||||
local_dir=destination,
|
||||
resume_download=True)
|
||||
if isinstance(known_file, HuggingFile):
|
||||
path = hf_hub_download(repo_id=known_file.repo_id,
|
||||
filename=known_file.filename,
|
||||
local_dir=destination,
|
||||
resume_download=True)
|
||||
else:
|
||||
url: Optional[str] = None
|
||||
|
||||
if isinstance(known_file, CivitFile):
|
||||
model_info_res = session.get(
|
||||
f"https://civitai.com/api/v1/models/{known_file.model_id}?modelVersionId={known_file.model_version_id}")
|
||||
model_info: CivitModelsGetResponse = model_info_res.json()
|
||||
for model_version in model_info['modelVersions']:
|
||||
for file in model_version['files']:
|
||||
if file['name'] == filename:
|
||||
url = file['downloadUrl']
|
||||
break
|
||||
if url is not None:
|
||||
break
|
||||
else:
|
||||
raise RuntimeError("unknown file type")
|
||||
|
||||
if url is None:
|
||||
logging.warning(f"Could not retrieve file {str(known_file)}")
|
||||
else:
|
||||
with session.get(url, stream=True, allow_redirects=True) as response:
|
||||
total_size = int(response.headers.get("content-length", 0))
|
||||
progress_bar = ProgressBar(total=total_size)
|
||||
with open(join(destination, filename), "wb") as file:
|
||||
for chunk in response.iter_content(chunk_size=512 * 1024):
|
||||
progress_bar.update(len(chunk))
|
||||
file.write(chunk)
|
||||
path = folder_paths.get_full_path(folder_name, filename)
|
||||
assert path is not None
|
||||
except StopIteration:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
except Exception as exc:
|
||||
logging.error("Error while trying to download a file", exc_info=exc)
|
||||
return path
|
||||
|
||||
|
||||
@ -69,6 +87,9 @@ KNOWN_CHECKPOINTS = [
|
||||
HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False),
|
||||
# from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md
|
||||
HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False),
|
||||
# from https://github.com/huchenlei/ComfyUI-layerdiffuse
|
||||
CivitFile(133005, 357609, filename="juggernautXL_v8Rundiffusion.safetensors"),
|
||||
CivitFile(133005, 357609, filename="juggernautXL_v9Rundiffusionphoto2.safetensors"),
|
||||
]
|
||||
|
||||
KNOWN_UNCLIP_CHECKPOINTS = [
|
||||
|
||||
139
comfy/model_downloader_types.py
Normal file
139
comfy/model_downloader_types.py
Normal file
@ -0,0 +1,139 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
from os.path import split
|
||||
from typing import Optional
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class CivitFile:
|
||||
"""
|
||||
A file on CivitAI
|
||||
|
||||
Attributes:
|
||||
model_id (int): The ID of the model
|
||||
model_version_id (int): The version
|
||||
filename (str): The name of the file in the model
|
||||
"""
|
||||
model_id: int
|
||||
model_version_id: int
|
||||
filename: str
|
||||
|
||||
def __str__(self):
|
||||
return self.filename
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HuggingFile:
|
||||
"""
|
||||
A file on Huggingface Hub
|
||||
|
||||
Attributes:
|
||||
repo_id (str): The Huggingface repository of a known file
|
||||
filename (str): The path to the known file in the repository
|
||||
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
|
||||
"""
|
||||
repo_id: str
|
||||
filename: str
|
||||
show_in_ui: Optional[bool] = True
|
||||
|
||||
def __str__(self):
|
||||
return split(self.filename)[-1]
|
||||
|
||||
|
||||
from typing import TypedDict, List, Optional, NotRequired
|
||||
|
||||
|
||||
class CivitStats(TypedDict):
|
||||
downloadCount: int
|
||||
favoriteCount: NotRequired[int]
|
||||
thumbsUpCount: int
|
||||
thumbsDownCount: int
|
||||
commentCount: int
|
||||
ratingCount: int
|
||||
rating: float
|
||||
tippedAmountCount: NotRequired[int]
|
||||
|
||||
|
||||
class CivitCreator(TypedDict):
|
||||
username: str
|
||||
image: str
|
||||
|
||||
|
||||
class CivitFileMetadata(TypedDict, total=False):
|
||||
fp: Optional[str]
|
||||
size: Optional[str]
|
||||
format: Optional[str]
|
||||
|
||||
|
||||
class CivitFile_(TypedDict):
|
||||
id: int
|
||||
sizeKB: float
|
||||
name: str
|
||||
type: str
|
||||
metadata: CivitFileMetadata
|
||||
pickleScanResult: str
|
||||
pickleScanMessage: Optional[str]
|
||||
virusScanResult: str
|
||||
virusScanMessage: Optional[str]
|
||||
scannedAt: str
|
||||
hashes: dict
|
||||
downloadUrl: str
|
||||
primary: bool
|
||||
|
||||
|
||||
class CivitImageMetadata(TypedDict):
|
||||
hash: str
|
||||
size: int
|
||||
width: int
|
||||
height: int
|
||||
|
||||
|
||||
class CivitImage(TypedDict):
|
||||
url: str
|
||||
nsfw: str
|
||||
width: int
|
||||
height: int
|
||||
hash: str
|
||||
type: str
|
||||
metadata: CivitImageMetadata
|
||||
availability: str
|
||||
|
||||
|
||||
class CivitModelVersion(TypedDict):
|
||||
id: int
|
||||
modelId: int
|
||||
name: str
|
||||
createdAt: str
|
||||
updatedAt: str
|
||||
status: str
|
||||
publishedAt: str
|
||||
trainedWords: List[str]
|
||||
trainingStatus: NotRequired[Optional[str]]
|
||||
trainingDetails: NotRequired[Optional[str]]
|
||||
baseModel: str
|
||||
baseModelType: str
|
||||
earlyAccessTimeFrame: int
|
||||
description: str
|
||||
vaeId: NotRequired[Optional[int]]
|
||||
stats: CivitStats
|
||||
files: List[CivitFile_]
|
||||
images: List[CivitImage]
|
||||
downloadUrl: str
|
||||
|
||||
|
||||
class CivitModelsGetResponse(TypedDict):
|
||||
id: int
|
||||
name: str
|
||||
description: str
|
||||
type: str
|
||||
poi: bool
|
||||
nsfw: bool
|
||||
allowNoCredit: bool
|
||||
allowCommercialUse: List[str]
|
||||
allowDerivatives: bool
|
||||
allowDifferentLicense: bool
|
||||
stats: CivitStats
|
||||
creator: CivitCreator
|
||||
tags: List[str]
|
||||
modelVersions: List[CivitModelVersion]
|
||||
@ -23,8 +23,9 @@ from .. import model_management
|
||||
from ..cli_args import args
|
||||
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
|
||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
|
||||
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS
|
||||
from ..model_downloader_types import HuggingFile
|
||||
from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user