Download known models from HuggingFace

This commit is contained in:
Benjamin Berman 2024-03-11 00:15:06 -07:00
parent 175a50d7ba
commit 3c57ef831c
6 changed files with 145 additions and 15 deletions

95
comfy/model_downloader.py Normal file
View File

@ -0,0 +1,95 @@
import dataclasses
from typing import List, Optional
from huggingface_hub import hf_hub_download
from .cmd import folder_paths
from .utils import comfy_tqdm
from posixpath import split
@dataclasses.dataclass
class HuggingFile:
"""
A file on Huggingface Hub
Attributes:
repo_id (str): The Huggingface repository of a known file
filename (str): The path to the known file in the repository
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
"""
repo_id: str
filename: str
show_in_ui: Optional[bool] = True
def __str__(self):
return split(self.filename)[-1]
def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]:
existing = frozenset(folder_paths.get_filename_list(folder_name))
downloadable = frozenset(str(f) for f in known_huggingface_files)
return sorted(list(existing | downloadable))
def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str:
path = folder_paths.get_full_path(folder_name, filename)
if path is None:
try:
destination = folder_paths.get_folder_paths(folder_name)[0]
hugging_file = next(f for f in known_huggingface_files if str(f) == filename)
with comfy_tqdm():
path = hf_hub_download(repo_id=hugging_file.repo_id,
filename=hugging_file.filename,
local_dir=destination,
resume_download=True)
except StopIteration:
pass
except Exception:
pass
return path
KNOWN_CHECKPOINTS = [
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0.safetensors", show_in_ui=False),
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_b.safetensors"),
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stage_a.safetensors"),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors"),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt", show_in_ui=False),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.ckpt", show_in_ui=False),
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.safetensors", show_in_ui=False),
# from https://github.com/comfyanonymous/ComfyUI_examples/tree/master/2_pass_txt2img
HuggingFile("stabilityai/stable-diffusion-2-1", "v2-1_768-ema-pruned.ckpt", show_in_ui=False),
HuggingFile("waifu-diffusion/wd-1-5-beta3", "wd-illusion-fp16.safetensors", show_in_ui=False),
HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False),
# from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md
HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False),
]
KNOWN_UNCLIP_CHECKPOINTS = [
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
]
KNOWN_IMAGE_ONLY_CHECKPOINTS = [
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
]
KNOWN_UPSCALERS = [
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
]
KNOWN_GLIGEN_MODELS = [
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors"),
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
]
KNOWN_CLIP_VISION_MODELS = [
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
]

View File

@ -22,6 +22,8 @@ from .. import model_management
from ..cli_args import args from ..cli_args import args
from ..cmd import folder_paths, latent_preview from ..cmd import folder_paths, latent_preview
from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS
from ..nodes.common import MAX_RESOLUTION from ..nodes.common import MAX_RESOLUTION
from .. import controlnet from .. import controlnet
@ -497,7 +499,7 @@ class CheckpointLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "config_name": (folder_paths.get_filename_list("configs"),), return {"required": { "config_name": (folder_paths.get_filename_list("configs"),),
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}} "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}}
RETURN_TYPES = ("MODEL", "CLIP", "VAE") RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
@ -505,13 +507,13 @@ class CheckpointLoader:
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
config_path = folder_paths.get_full_path("configs", config_name) config_path = folder_paths.get_full_path("configs", config_name)
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
class CheckpointLoaderSimple: class CheckpointLoaderSimple:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),),
}} }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE") RETURN_TYPES = ("MODEL", "CLIP", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
@ -519,7 +521,7 @@ class CheckpointLoaderSimple:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out[:3] return out[:3]
@ -553,7 +555,7 @@ class DiffusersLoader:
class unCLIPCheckpointLoader: class unCLIPCheckpointLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),), return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),),
}} }}
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION") RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
@ -561,7 +563,7 @@ class unCLIPCheckpointLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return out return out
@ -861,7 +863,7 @@ class DualCLIPLoader:
class CLIPVisionLoader: class CLIPVisionLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"),), return {"required": { "clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),),
}} }}
RETURN_TYPES = ("CLIP_VISION",) RETURN_TYPES = ("CLIP_VISION",)
FUNCTION = "load_clip" FUNCTION = "load_clip"
@ -869,7 +871,7 @@ class CLIPVisionLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_clip(self, clip_name): def load_clip(self, clip_name):
clip_path = folder_paths.get_full_path("clip_vision", clip_name) clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
clip_vision = clip_vision_module.load(clip_path) clip_vision = clip_vision_module.load(clip_path)
return (clip_vision,) return (clip_vision,)
@ -956,7 +958,7 @@ class unCLIPConditioning:
class GLIGENLoader: class GLIGENLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"),)}} return {"required": { "gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}}
RETURN_TYPES = ("GLIGEN",) RETURN_TYPES = ("GLIGEN",)
FUNCTION = "load_gligen" FUNCTION = "load_gligen"
@ -964,7 +966,7 @@ class GLIGENLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_gligen(self, gligen_name): def load_gligen(self, gligen_name):
gligen_path = folder_paths.get_full_path("gligen", gligen_name) gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
gligen = sd.load_gligen(gligen_path) gligen = sd.load_gligen(gligen_path)
return (gligen,) return (gligen,)

View File

@ -7,6 +7,9 @@ from . import checkpoint_pickle
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
from PIL import Image from PIL import Image
from tqdm import tqdm
from contextlib import contextmanager
def load_torch_file(ckpt, safe_load=False, device=None): def load_torch_file(ckpt, safe_load=False, device=None):
if device is None: if device is None:
@ -486,3 +489,30 @@ class ProgressBar:
def get_project_root() -> str: def get_project_root() -> str:
return os.path.join(os.path.dirname(__file__), "..") return os.path.join(os.path.dirname(__file__), "..")
@contextmanager
def comfy_tqdm():
"""
Monky patches child calls to tqdm and sends the progress to the UI
:return:
"""
_original_init = tqdm.__init__
_original_update = tqdm.update
try:
def __init(self, *args, **kwargs):
_original_init(self, *args, **kwargs)
self._progress_bar = ProgressBar(self.total)
def __update(self, n=1):
assert self._progress_bar is not None
_original_update(self, n)
self._progress_bar.update(n)
tqdm.__init__ = __init
tqdm.update = __update
yield
finally:
# Restore original tqdm
tqdm.__init__ = _original_init
tqdm.update = _original_update

View File

@ -1,3 +1,4 @@
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
from ..chainner_models import model_loading from ..chainner_models import model_loading
from comfy import model_management from comfy import model_management
import torch import torch
@ -8,7 +9,7 @@ from comfy.cmd import folder_paths
class UpscaleModelLoader: class UpscaleModelLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"),), return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
}} }}
RETURN_TYPES = ("UPSCALE_MODEL",) RETURN_TYPES = ("UPSCALE_MODEL",)
@ -17,7 +18,7 @@ class UpscaleModelLoader:
CATEGORY = "loaders" CATEGORY = "loaders"
def load_model(self, model_name): def load_model(self, model_name):
model_path = folder_paths.get_full_path("upscale_models", model_name) model_path = get_or_download("upscale_models", model_name, KNOWN_UPSCALERS)
sd = utils.load_torch_file(model_path, safe_load=True) sd = utils.load_torch_file(model_path, safe_load=True)
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd: if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
sd = utils.state_dict_prefix_replace(sd, {"module.": ""}) sd = utils.state_dict_prefix_replace(sd, {"module.": ""})

View File

@ -1,3 +1,4 @@
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_IMAGE_ONLY_CHECKPOINTS, get_or_download
from comfy.nodes.common import MAX_RESOLUTION from comfy.nodes.common import MAX_RESOLUTION
import torch import torch
import comfy.utils import comfy.utils
@ -9,7 +10,7 @@ from . import nodes_model_merging
class ImageOnlyCheckpointLoader: class ImageOnlyCheckpointLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ), return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS), ),
}} }}
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE") RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
FUNCTION = "load_checkpoint" FUNCTION = "load_checkpoint"
@ -17,7 +18,7 @@ class ImageOnlyCheckpointLoader:
CATEGORY = "loaders/video_models" CATEGORY = "loaders/video_models"
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True): def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name) ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_IMAGE_ONLY_CHECKPOINTS)
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings")) out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
return (out[0], out[3], out[2]) return (out[0], out[3], out[2])

View File

@ -30,4 +30,5 @@ ConfigArgParse
aio-pika aio-pika
pyjwt[crypto] pyjwt[crypto]
kornia>=0.7.1 kornia>=0.7.1
mpmath>=1.0,!=1.4.0a0 mpmath>=1.0,!=1.4.0a0
huggingface_hub