mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
Download known models from HuggingFace
This commit is contained in:
parent
175a50d7ba
commit
3c57ef831c
95
comfy/model_downloader.py
Normal file
95
comfy/model_downloader.py
Normal file
@ -0,0 +1,95 @@
|
||||
import dataclasses
|
||||
from typing import List, Optional
|
||||
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from .cmd import folder_paths
|
||||
from .utils import comfy_tqdm
|
||||
from posixpath import split
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class HuggingFile:
|
||||
"""
|
||||
A file on Huggingface Hub
|
||||
|
||||
Attributes:
|
||||
repo_id (str): The Huggingface repository of a known file
|
||||
filename (str): The path to the known file in the repository
|
||||
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
|
||||
"""
|
||||
repo_id: str
|
||||
filename: str
|
||||
show_in_ui: Optional[bool] = True
|
||||
|
||||
def __str__(self):
|
||||
return split(self.filename)[-1]
|
||||
|
||||
|
||||
def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]:
|
||||
existing = frozenset(folder_paths.get_filename_list(folder_name))
|
||||
downloadable = frozenset(str(f) for f in known_huggingface_files)
|
||||
return sorted(list(existing | downloadable))
|
||||
|
||||
|
||||
def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str:
|
||||
path = folder_paths.get_full_path(folder_name, filename)
|
||||
|
||||
if path is None:
|
||||
try:
|
||||
destination = folder_paths.get_folder_paths(folder_name)[0]
|
||||
hugging_file = next(f for f in known_huggingface_files if str(f) == filename)
|
||||
with comfy_tqdm():
|
||||
path = hf_hub_download(repo_id=hugging_file.repo_id,
|
||||
filename=hugging_file.filename,
|
||||
local_dir=destination,
|
||||
resume_download=True)
|
||||
except StopIteration:
|
||||
pass
|
||||
except Exception:
|
||||
pass
|
||||
return path
|
||||
|
||||
|
||||
KNOWN_CHECKPOINTS = [
|
||||
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
|
||||
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
|
||||
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
|
||||
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0.safetensors", show_in_ui=False),
|
||||
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_b.safetensors"),
|
||||
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
|
||||
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stage_a.safetensors"),
|
||||
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors"),
|
||||
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt", show_in_ui=False),
|
||||
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.ckpt", show_in_ui=False),
|
||||
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.safetensors", show_in_ui=False),
|
||||
# from https://github.com/comfyanonymous/ComfyUI_examples/tree/master/2_pass_txt2img
|
||||
HuggingFile("stabilityai/stable-diffusion-2-1", "v2-1_768-ema-pruned.ckpt", show_in_ui=False),
|
||||
HuggingFile("waifu-diffusion/wd-1-5-beta3", "wd-illusion-fp16.safetensors", show_in_ui=False),
|
||||
HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False),
|
||||
# from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md
|
||||
HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False),
|
||||
]
|
||||
|
||||
KNOWN_UNCLIP_CHECKPOINTS = [
|
||||
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
|
||||
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
|
||||
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
|
||||
]
|
||||
|
||||
KNOWN_IMAGE_ONLY_CHECKPOINTS = [
|
||||
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
|
||||
]
|
||||
|
||||
KNOWN_UPSCALERS = [
|
||||
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
|
||||
]
|
||||
|
||||
KNOWN_GLIGEN_MODELS = [
|
||||
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors"),
|
||||
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
|
||||
]
|
||||
|
||||
KNOWN_CLIP_VISION_MODELS = [
|
||||
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
|
||||
]
|
||||
@ -22,6 +22,8 @@ from .. import model_management
|
||||
from ..cli_args import args
|
||||
|
||||
from ..cmd import folder_paths, latent_preview
|
||||
from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
|
||||
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS
|
||||
from ..nodes.common import MAX_RESOLUTION
|
||||
from .. import controlnet
|
||||
|
||||
@ -497,7 +499,7 @@ class CheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "config_name": (folder_paths.get_filename_list("configs"),),
|
||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}}
|
||||
"ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
|
||||
@ -505,13 +507,13 @@ class CheckpointLoader:
|
||||
|
||||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||
config_path = folder_paths.get_full_path("configs", config_name)
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
||||
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
|
||||
class CheckpointLoaderSimple:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
|
||||
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -519,7 +521,7 @@ class CheckpointLoaderSimple:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
||||
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out[:3]
|
||||
|
||||
@ -553,7 +555,7 @@ class DiffusersLoader:
|
||||
class unCLIPCheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
|
||||
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -561,7 +563,7 @@ class unCLIPCheckpointLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
|
||||
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return out
|
||||
|
||||
@ -861,7 +863,7 @@ class DualCLIPLoader:
|
||||
class CLIPVisionLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"),),
|
||||
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),),
|
||||
}}
|
||||
RETURN_TYPES = ("CLIP_VISION",)
|
||||
FUNCTION = "load_clip"
|
||||
@ -869,7 +871,7 @@ class CLIPVisionLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_clip(self, clip_name):
|
||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
||||
clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
|
||||
clip_vision = clip_vision_module.load(clip_path)
|
||||
return (clip_vision,)
|
||||
|
||||
@ -956,7 +958,7 @@ class unCLIPConditioning:
|
||||
class GLIGENLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"),)}}
|
||||
return {"required": { "gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}}
|
||||
|
||||
RETURN_TYPES = ("GLIGEN",)
|
||||
FUNCTION = "load_gligen"
|
||||
@ -964,7 +966,7 @@ class GLIGENLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_gligen(self, gligen_name):
|
||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
||||
gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
|
||||
gligen = sd.load_gligen(gligen_path)
|
||||
return (gligen,)
|
||||
|
||||
|
||||
@ -7,6 +7,9 @@ from . import checkpoint_pickle
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
from PIL import Image
|
||||
from tqdm import tqdm
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||
if device is None:
|
||||
@ -486,3 +489,30 @@ class ProgressBar:
|
||||
|
||||
def get_project_root() -> str:
|
||||
return os.path.join(os.path.dirname(__file__), "..")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def comfy_tqdm():
|
||||
"""
|
||||
Monky patches child calls to tqdm and sends the progress to the UI
|
||||
:return:
|
||||
"""
|
||||
_original_init = tqdm.__init__
|
||||
_original_update = tqdm.update
|
||||
try:
|
||||
def __init(self, *args, **kwargs):
|
||||
_original_init(self, *args, **kwargs)
|
||||
self._progress_bar = ProgressBar(self.total)
|
||||
|
||||
def __update(self, n=1):
|
||||
assert self._progress_bar is not None
|
||||
_original_update(self, n)
|
||||
self._progress_bar.update(n)
|
||||
|
||||
tqdm.__init__ = __init
|
||||
tqdm.update = __update
|
||||
yield
|
||||
finally:
|
||||
# Restore original tqdm
|
||||
tqdm.__init__ = _original_init
|
||||
tqdm.update = _original_update
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
|
||||
from ..chainner_models import model_loading
|
||||
from comfy import model_management
|
||||
import torch
|
||||
@ -8,7 +9,7 @@ from comfy.cmd import folder_paths
|
||||
class UpscaleModelLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"),),
|
||||
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
|
||||
}}
|
||||
|
||||
RETURN_TYPES = ("UPSCALE_MODEL",)
|
||||
@ -17,7 +18,7 @@ class UpscaleModelLoader:
|
||||
CATEGORY = "loaders"
|
||||
|
||||
def load_model(self, model_name):
|
||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
||||
model_path = get_or_download("upscale_models", model_name, KNOWN_UPSCALERS)
|
||||
sd = utils.load_torch_file(model_path, safe_load=True)
|
||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||
sd = utils.state_dict_prefix_replace(sd, {"module.": ""})
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_IMAGE_ONLY_CHECKPOINTS, get_or_download
|
||||
from comfy.nodes.common import MAX_RESOLUTION
|
||||
import torch
|
||||
import comfy.utils
|
||||
@ -9,7 +10,7 @@ from . import nodes_model_merging
|
||||
class ImageOnlyCheckpointLoader:
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
||||
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS), ),
|
||||
}}
|
||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||
FUNCTION = "load_checkpoint"
|
||||
@ -17,7 +18,7 @@ class ImageOnlyCheckpointLoader:
|
||||
CATEGORY = "loaders/video_models"
|
||||
|
||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
||||
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_IMAGE_ONLY_CHECKPOINTS)
|
||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||
return (out[0], out[3], out[2])
|
||||
|
||||
|
||||
@ -30,4 +30,5 @@ ConfigArgParse
|
||||
aio-pika
|
||||
pyjwt[crypto]
|
||||
kornia>=0.7.1
|
||||
mpmath>=1.0,!=1.4.0a0
|
||||
mpmath>=1.0,!=1.4.0a0
|
||||
huggingface_hub
|
||||
Loading…
Reference in New Issue
Block a user