From 3c57ef831cfc9bfb343e5649dd0611bc52656716 Mon Sep 17 00:00:00 2001 From: Benjamin Berman Date: Mon, 11 Mar 2024 00:15:06 -0700 Subject: [PATCH] Download known models from HuggingFace --- comfy/model_downloader.py | 95 +++++++++++++++++++++++ comfy/nodes/base_nodes.py | 22 +++--- comfy/utils.py | 30 +++++++ comfy_extras/nodes/nodes_upscale_model.py | 5 +- comfy_extras/nodes/nodes_video_model.py | 5 +- requirements.txt | 3 +- 6 files changed, 145 insertions(+), 15 deletions(-) create mode 100644 comfy/model_downloader.py diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py new file mode 100644 index 000000000..9bed8907c --- /dev/null +++ b/comfy/model_downloader.py @@ -0,0 +1,95 @@ +import dataclasses +from typing import List, Optional + +from huggingface_hub import hf_hub_download + +from .cmd import folder_paths +from .utils import comfy_tqdm +from posixpath import split + + +@dataclasses.dataclass +class HuggingFile: + """ + A file on Huggingface Hub + + Attributes: + repo_id (str): The Huggingface repository of a known file + filename (str): The path to the known file in the repository + show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter + """ + repo_id: str + filename: str + show_in_ui: Optional[bool] = True + + def __str__(self): + return split(self.filename)[-1] + + +def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]: + existing = frozenset(folder_paths.get_filename_list(folder_name)) + downloadable = frozenset(str(f) for f in known_huggingface_files) + return sorted(list(existing | downloadable)) + + +def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str: + path = folder_paths.get_full_path(folder_name, filename) + + if path is None: + try: + destination = folder_paths.get_folder_paths(folder_name)[0] + hugging_file = next(f for f in known_huggingface_files if str(f) == filename) + with comfy_tqdm(): + path = hf_hub_download(repo_id=hugging_file.repo_id, + filename=hugging_file.filename, + local_dir=destination, + resume_download=True) + except StopIteration: + pass + except Exception: + pass + return path + + +KNOWN_CHECKPOINTS = [ + HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"), + HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"), + HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"), + HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0.safetensors", show_in_ui=False), + HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_b.safetensors"), + HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"), + HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stage_a.safetensors"), + HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors"), + HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt", show_in_ui=False), + HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.ckpt", show_in_ui=False), + HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.safetensors", show_in_ui=False), + # from https://github.com/comfyanonymous/ComfyUI_examples/tree/master/2_pass_txt2img + HuggingFile("stabilityai/stable-diffusion-2-1", "v2-1_768-ema-pruned.ckpt", show_in_ui=False), + HuggingFile("waifu-diffusion/wd-1-5-beta3", "wd-illusion-fp16.safetensors", show_in_ui=False), + HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False), + # from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md + HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False), +] + +KNOWN_UNCLIP_CHECKPOINTS = [ + HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"), + HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"), + HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"), +] + +KNOWN_IMAGE_ONLY_CHECKPOINTS = [ + HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt") +] + +KNOWN_UPSCALERS = [ + HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth") +] + +KNOWN_GLIGEN_MODELS = [ + HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors"), + HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"), +] + +KNOWN_CLIP_VISION_MODELS = [ + HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors") +] diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py index 3d8aad92a..4ef071b59 100644 --- a/comfy/nodes/base_nodes.py +++ b/comfy/nodes/base_nodes.py @@ -22,6 +22,8 @@ from .. import model_management from ..cli_args import args from ..cmd import folder_paths, latent_preview +from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \ + KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS from ..nodes.common import MAX_RESOLUTION from .. import controlnet @@ -497,7 +499,7 @@ class CheckpointLoader: @classmethod def INPUT_TYPES(s): return {"required": { "config_name": (folder_paths.get_filename_list("configs"),), - "ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}} + "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -505,13 +507,13 @@ class CheckpointLoader: def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): config_path = folder_paths.get_full_path("configs", config_name) - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS) return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) class CheckpointLoaderSimple: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), + return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE") FUNCTION = "load_checkpoint" @@ -519,7 +521,7 @@ class CheckpointLoaderSimple: CATEGORY = "loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS) out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out[:3] @@ -553,7 +555,7 @@ class DiffusersLoader: class unCLIPCheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), + return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),), }} RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") FUNCTION = "load_checkpoint" @@ -561,7 +563,7 @@ class unCLIPCheckpointLoader: CATEGORY = "loaders" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS) out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return out @@ -861,7 +863,7 @@ class DualCLIPLoader: class CLIPVisionLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"),), + return {"required": { "clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),), }} RETURN_TYPES = ("CLIP_VISION",) FUNCTION = "load_clip" @@ -869,7 +871,7 @@ class CLIPVisionLoader: CATEGORY = "loaders" def load_clip(self, clip_name): - clip_path = folder_paths.get_full_path("clip_vision", clip_name) + clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS) clip_vision = clip_vision_module.load(clip_path) return (clip_vision,) @@ -956,7 +958,7 @@ class unCLIPConditioning: class GLIGENLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"),)}} + return {"required": { "gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}} RETURN_TYPES = ("GLIGEN",) FUNCTION = "load_gligen" @@ -964,7 +966,7 @@ class GLIGENLoader: CATEGORY = "loaders" def load_gligen(self, gligen_name): - gligen_path = folder_paths.get_full_path("gligen", gligen_name) + gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS) gligen = sd.load_gligen(gligen_path) return (gligen,) diff --git a/comfy/utils.py b/comfy/utils.py index 80695c0f6..9386e6342 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -7,6 +7,9 @@ from . import checkpoint_pickle import safetensors.torch import numpy as np from PIL import Image +from tqdm import tqdm +from contextlib import contextmanager + def load_torch_file(ckpt, safe_load=False, device=None): if device is None: @@ -486,3 +489,30 @@ class ProgressBar: def get_project_root() -> str: return os.path.join(os.path.dirname(__file__), "..") + + +@contextmanager +def comfy_tqdm(): + """ + Monky patches child calls to tqdm and sends the progress to the UI + :return: + """ + _original_init = tqdm.__init__ + _original_update = tqdm.update + try: + def __init(self, *args, **kwargs): + _original_init(self, *args, **kwargs) + self._progress_bar = ProgressBar(self.total) + + def __update(self, n=1): + assert self._progress_bar is not None + _original_update(self, n) + self._progress_bar.update(n) + + tqdm.__init__ = __init + tqdm.update = __update + yield + finally: + # Restore original tqdm + tqdm.__init__ = _original_init + tqdm.update = _original_update diff --git a/comfy_extras/nodes/nodes_upscale_model.py b/comfy_extras/nodes/nodes_upscale_model.py index 31bd420a3..077c88f9f 100644 --- a/comfy_extras/nodes/nodes_upscale_model.py +++ b/comfy_extras/nodes/nodes_upscale_model.py @@ -1,3 +1,4 @@ +from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download from ..chainner_models import model_loading from comfy import model_management import torch @@ -8,7 +9,7 @@ from comfy.cmd import folder_paths class UpscaleModelLoader: @classmethod def INPUT_TYPES(s): - return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"),), + return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),), }} RETURN_TYPES = ("UPSCALE_MODEL",) @@ -17,7 +18,7 @@ class UpscaleModelLoader: CATEGORY = "loaders" def load_model(self, model_name): - model_path = folder_paths.get_full_path("upscale_models", model_name) + model_path = get_or_download("upscale_models", model_name, KNOWN_UPSCALERS) sd = utils.load_torch_file(model_path, safe_load=True) if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: sd = utils.state_dict_prefix_replace(sd, {"module.": ""}) diff --git a/comfy_extras/nodes/nodes_video_model.py b/comfy_extras/nodes/nodes_video_model.py index 4f54b5459..187495233 100644 --- a/comfy_extras/nodes/nodes_video_model.py +++ b/comfy_extras/nodes/nodes_video_model.py @@ -1,3 +1,4 @@ +from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_IMAGE_ONLY_CHECKPOINTS, get_or_download from comfy.nodes.common import MAX_RESOLUTION import torch import comfy.utils @@ -9,7 +10,7 @@ from . import nodes_model_merging class ImageOnlyCheckpointLoader: @classmethod def INPUT_TYPES(s): - return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), + return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS), ), }} RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") FUNCTION = "load_checkpoint" @@ -17,7 +18,7 @@ class ImageOnlyCheckpointLoader: CATEGORY = "loaders/video_models" def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): - ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) + ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_IMAGE_ONLY_CHECKPOINTS) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return (out[0], out[3], out[2]) diff --git a/requirements.txt b/requirements.txt index 3020cd3d0..ce9d90b2a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -30,4 +30,5 @@ ConfigArgParse aio-pika pyjwt[crypto] kornia>=0.7.1 -mpmath>=1.0,!=1.4.0a0 \ No newline at end of file +mpmath>=1.0,!=1.4.0a0 +huggingface_hub \ No newline at end of file