Download known models from HuggingFace

This commit is contained in:
Benjamin Berman 2024-03-11 00:15:06 -07:00
parent 175a50d7ba
commit 3c57ef831c
6 changed files with 145 additions and 15 deletions

95
comfy/model_downloader.py Normal file
View File

@ -0,0 +1,95 @@
import dataclasses
from typing import List, Optional
from huggingface_hub import hf_hub_download
from .cmd import folder_paths
from .utils import comfy_tqdm
from posixpath import split
@dataclasses.dataclass
class HuggingFile:
"""
A file on Huggingface Hub
Attributes:
repo_id (str): The Huggingface repository of a known file
filename (str): The path to the known file in the repository
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
"""
repo_id: str
filename: str
show_in_ui: Optional[bool] = True
def __str__(self):
return split(self.filename)[-1]
def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]:
existing = frozenset(folder_paths.get_filename_list(folder_name))
downloadable = frozenset(str(f) for f in known_huggingface_files)
return sorted(list(existing | downloadable))
def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str:
path = folder_paths.get_full_path(folder_name, filename)
if path is None:
try:
destination = folder_paths.get_folder_paths(folder_name)[0]
hugging_file = next(f for f in known_huggingface_files if str(f) == filename)
with comfy_tqdm():
path = hf_hub_download(repo_id=hugging_file.repo_id,
filename=hugging_file.filename,
local_dir=destination,
resume_download=True)
except StopIteration:
pass
except Exception:
pass
return path
KNOWN_CHECKPOINTS = [
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0.safetensors", show_in_ui=False),
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_b.safetensors"),
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stage_a.safetensors"),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors"),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt", show_in_ui=False),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.ckpt", show_in_ui=False),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.safetensors", show_in_ui=False),
# from https://github.com/comfyanonymous/ComfyUI_examples/tree/master/2_pass_txt2img
HuggingFile("stabilityai/stable-diffusion-2-1", "v2-1_768-ema-pruned.ckpt", show_in_ui=False),
HuggingFile("waifu-diffusion/wd-1-5-beta3", "wd-illusion-fp16.safetensors", show_in_ui=False),
HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False),
# from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md
HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False),
]
KNOWN_UNCLIP_CHECKPOINTS = [
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
]
KNOWN_IMAGE_ONLY_CHECKPOINTS = [
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
]
KNOWN_UPSCALERS = [
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
]
KNOWN_GLIGEN_MODELS = [
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors"),
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
]
KNOWN_CLIP_VISION_MODELS = [
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
]

View File

@ -22,6 +22,8 @@ from .. import model_management
from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
@ -497,7 +499,7 @@ class CheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"),),
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}}
"ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -505,13 +507,13 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint"
@ -519,7 +521,7 @@ class CheckpointLoaderSimple:
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3]
@ -553,7 +555,7 @@ class DiffusersLoader:
class unCLIPCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),),
}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
FUNCTION = "load_checkpoint"
@ -561,7 +563,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out
@ -861,7 +863,7 @@ class DualCLIPLoader:
class CLIPVisionLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"),),
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),),
}}
RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip"
@ -869,7 +871,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders"
def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
clip_vision = clip_vision_module.load(clip_path)
return (clip_vision,)
@ -956,7 +958,7 @@ class unCLIPConditioning:
class GLIGENLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"),)}}
return {"required": { "gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}}
RETURN_TYPES = ("GLIGEN",)
FUNCTION = "load_gligen"
@ -964,7 +966,7 @@ class GLIGENLoader:
CATEGORY = "loaders"
def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
gligen = sd.load_gligen(gligen_path)
return (gligen,)

View File

@ -7,6 +7,9 @@ from . import checkpoint_pickle
import safetensors.torch
import numpy as np
from PIL import Image
from tqdm import tqdm
from contextlib import contextmanager
def load_torch_file(ckpt, safe_load=False, device=None):
if device is None:
@ -486,3 +489,30 @@ class ProgressBar:
def get_project_root() -> str:
return os.path.join(os.path.dirname(__file__), "..")
@contextmanager
def comfy_tqdm():
"""
Monky patches child calls to tqdm and sends the progress to the UI
:return:
"""
_original_init = tqdm.__init__
_original_update = tqdm.update
try:
def __init(self, *args, **kwargs):
_original_init(self, *args, **kwargs)
self._progress_bar = ProgressBar(self.total)
def __update(self, n=1):
assert self._progress_bar is not None
_original_update(self, n)
self._progress_bar.update(n)
tqdm.__init__ = __init
tqdm.update = __update
yield
finally:
# Restore original tqdm
tqdm.__init__ = _original_init
tqdm.update = _original_update

View File

@ -1,3 +1,4 @@
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
from ..chainner_models import model_loading
from comfy import model_management
import torch
@ -8,7 +9,7 @@ from comfy.cmd import folder_paths
class UpscaleModelLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"),),
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
}}
RETURN_TYPES = ("UPSCALE_MODEL",)
@ -17,7 +18,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders"
def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name)
model_path = get_or_download("upscale_models", model_name, KNOWN_UPSCALERS)
sd = utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = utils.state_dict_prefix_replace(sd, {"module.": ""})

View File

@ -1,3 +1,4 @@
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_IMAGE_ONLY_CHECKPOINTS, get_or_download
from comfy.nodes.common import MAX_RESOLUTION
import torch
import comfy.utils
@ -9,7 +10,7 @@ from . import nodes_model_merging
class ImageOnlyCheckpointLoader:
@classmethod
def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS), ),
}}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint"
@ -17,7 +18,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_IMAGE_ONLY_CHECKPOINTS)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2])

View File

@ -30,4 +30,5 @@ ConfigArgParse
aio-pika
pyjwt[crypto]
kornia>=0.7.1
mpmath>=1.0,!=1.4.0a0
mpmath>=1.0,!=1.4.0a0
huggingface_hub