mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-23 04:40:15 +08:00
Download known models from HuggingFace
This commit is contained in:
parent
175a50d7ba
commit
3c57ef831c
95
comfy/model_downloader.py
Normal file
95
comfy/model_downloader.py
Normal file
@ -0,0 +1,95 @@
|
|||||||
|
import dataclasses
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from huggingface_hub import hf_hub_download
|
||||||
|
|
||||||
|
from .cmd import folder_paths
|
||||||
|
from .utils import comfy_tqdm
|
||||||
|
from posixpath import split
|
||||||
|
|
||||||
|
|
||||||
|
@dataclasses.dataclass
|
||||||
|
class HuggingFile:
|
||||||
|
"""
|
||||||
|
A file on Huggingface Hub
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
repo_id (str): The Huggingface repository of a known file
|
||||||
|
filename (str): The path to the known file in the repository
|
||||||
|
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
|
||||||
|
"""
|
||||||
|
repo_id: str
|
||||||
|
filename: str
|
||||||
|
show_in_ui: Optional[bool] = True
|
||||||
|
|
||||||
|
def __str__(self):
|
||||||
|
return split(self.filename)[-1]
|
||||||
|
|
||||||
|
|
||||||
|
def get_filename_list_with_downloadable(folder_name: str, known_huggingface_files: List[HuggingFile]) -> List[str]:
|
||||||
|
existing = frozenset(folder_paths.get_filename_list(folder_name))
|
||||||
|
downloadable = frozenset(str(f) for f in known_huggingface_files)
|
||||||
|
return sorted(list(existing | downloadable))
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_download(folder_name: str, filename: str, known_huggingface_files: List[HuggingFile]) -> str:
|
||||||
|
path = folder_paths.get_full_path(folder_name, filename)
|
||||||
|
|
||||||
|
if path is None:
|
||||||
|
try:
|
||||||
|
destination = folder_paths.get_folder_paths(folder_name)[0]
|
||||||
|
hugging_file = next(f for f in known_huggingface_files if str(f) == filename)
|
||||||
|
with comfy_tqdm():
|
||||||
|
path = hf_hub_download(repo_id=hugging_file.repo_id,
|
||||||
|
filename=hugging_file.filename,
|
||||||
|
local_dir=destination,
|
||||||
|
resume_download=True)
|
||||||
|
except StopIteration:
|
||||||
|
pass
|
||||||
|
except Exception:
|
||||||
|
pass
|
||||||
|
return path
|
||||||
|
|
||||||
|
|
||||||
|
KNOWN_CHECKPOINTS = [
|
||||||
|
HuggingFile("stabilityai/stable-diffusion-xl-base-1.0", "sd_xl_base_1.0.safetensors"),
|
||||||
|
HuggingFile("stabilityai/stable-diffusion-xl-refiner-1.0", "sd_xl_refiner_1.0.safetensors"),
|
||||||
|
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0_fp16.safetensors"),
|
||||||
|
HuggingFile("stabilityai/sdxl-turbo", "sd_xl_turbo_1.0.safetensors", show_in_ui=False),
|
||||||
|
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_b.safetensors"),
|
||||||
|
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
|
||||||
|
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stage_a.safetensors"),
|
||||||
|
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.safetensors"),
|
||||||
|
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned-emaonly.ckpt", show_in_ui=False),
|
||||||
|
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.ckpt", show_in_ui=False),
|
||||||
|
HuggingFile("runwayml/stable-diffusion-v1-5", "v1-5-pruned.safetensors", show_in_ui=False),
|
||||||
|
# from https://github.com/comfyanonymous/ComfyUI_examples/tree/master/2_pass_txt2img
|
||||||
|
HuggingFile("stabilityai/stable-diffusion-2-1", "v2-1_768-ema-pruned.ckpt", show_in_ui=False),
|
||||||
|
HuggingFile("waifu-diffusion/wd-1-5-beta3", "wd-illusion-fp16.safetensors", show_in_ui=False),
|
||||||
|
HuggingFile("jomcs/NeverEnding_Dream-Feb19-2023", "CarDos Anime/cardosAnime_v10.safetensors", show_in_ui=False),
|
||||||
|
# from https://github.com/comfyanonymous/ComfyUI_examples/blob/master/area_composition/README.md
|
||||||
|
HuggingFile("ckpt/anything-v3.0", "Anything-V3.0.ckpt", show_in_ui=False),
|
||||||
|
]
|
||||||
|
|
||||||
|
KNOWN_UNCLIP_CHECKPOINTS = [
|
||||||
|
HuggingFile("stabilityai/stable-cascade", "comfyui_checkpoints/stable_cascade_stage_c.safetensors"),
|
||||||
|
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-h.ckpt"),
|
||||||
|
HuggingFile("stabilityai/stable-diffusion-2-1-unclip", "sd21-unclip-l.ckpt"),
|
||||||
|
]
|
||||||
|
|
||||||
|
KNOWN_IMAGE_ONLY_CHECKPOINTS = [
|
||||||
|
HuggingFile("stabilityai/stable-zero123", "stable_zero123.ckpt")
|
||||||
|
]
|
||||||
|
|
||||||
|
KNOWN_UPSCALERS = [
|
||||||
|
HuggingFile("lllyasviel/Annotators", "RealESRGAN_x4plus.pth")
|
||||||
|
]
|
||||||
|
|
||||||
|
KNOWN_GLIGEN_MODELS = [
|
||||||
|
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned.safetensors"),
|
||||||
|
HuggingFile("comfyanonymous/GLIGEN_pruned_safetensors", "gligen_sd14_textbox_pruned_fp16.safetensors"),
|
||||||
|
]
|
||||||
|
|
||||||
|
KNOWN_CLIP_VISION_MODELS = [
|
||||||
|
HuggingFile("comfyanonymous/clip_vision_g", "clip_vision_g.safetensors")
|
||||||
|
]
|
||||||
@ -22,6 +22,8 @@ from .. import model_management
|
|||||||
from ..cli_args import args
|
from ..cli_args import args
|
||||||
|
|
||||||
from ..cmd import folder_paths, latent_preview
|
from ..cmd import folder_paths, latent_preview
|
||||||
|
from ..model_downloader import HuggingFile, get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, \
|
||||||
|
KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS
|
||||||
from ..nodes.common import MAX_RESOLUTION
|
from ..nodes.common import MAX_RESOLUTION
|
||||||
from .. import controlnet
|
from .. import controlnet
|
||||||
|
|
||||||
@ -497,7 +499,7 @@ class CheckpointLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "config_name": (folder_paths.get_filename_list("configs"),),
|
return {"required": { "config_name": (folder_paths.get_filename_list("configs"),),
|
||||||
"ckpt_name": (folder_paths.get_filename_list("checkpoints"),)}}
|
"ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),)}}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
|
|
||||||
@ -505,13 +507,13 @@ class CheckpointLoader:
|
|||||||
|
|
||||||
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, config_name, ckpt_name, output_vae=True, output_clip=True):
|
||||||
config_path = folder_paths.get_full_path("configs", config_name)
|
config_path = folder_paths.get_full_path("configs", config_name)
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
||||||
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
return sd.load_checkpoint(config_path, ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
|
|
||||||
class CheckpointLoaderSimple:
|
class CheckpointLoaderSimple:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
|
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_CHECKPOINTS),),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
@ -519,7 +521,7 @@ class CheckpointLoaderSimple:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_CHECKPOINTS)
|
||||||
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out[:3]
|
return out[:3]
|
||||||
|
|
||||||
@ -553,7 +555,7 @@ class DiffusersLoader:
|
|||||||
class unCLIPCheckpointLoader:
|
class unCLIPCheckpointLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"),),
|
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_UNCLIP_CHECKPOINTS),),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
|
RETURN_TYPES = ("MODEL", "CLIP", "VAE", "CLIP_VISION")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
@ -561,7 +563,7 @@ class unCLIPCheckpointLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_UNCLIP_CHECKPOINTS)
|
||||||
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=True, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
@ -861,7 +863,7 @@ class DualCLIPLoader:
|
|||||||
class CLIPVisionLoader:
|
class CLIPVisionLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "clip_name": (folder_paths.get_filename_list("clip_vision"),),
|
return {"required": { "clip_name": (get_filename_list_with_downloadable("clip_vision", KNOWN_CLIP_VISION_MODELS),),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("CLIP_VISION",)
|
RETURN_TYPES = ("CLIP_VISION",)
|
||||||
FUNCTION = "load_clip"
|
FUNCTION = "load_clip"
|
||||||
@ -869,7 +871,7 @@ class CLIPVisionLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_clip(self, clip_name):
|
def load_clip(self, clip_name):
|
||||||
clip_path = folder_paths.get_full_path("clip_vision", clip_name)
|
clip_path = get_or_download("clip_vision", clip_name, KNOWN_CLIP_VISION_MODELS)
|
||||||
clip_vision = clip_vision_module.load(clip_path)
|
clip_vision = clip_vision_module.load(clip_path)
|
||||||
return (clip_vision,)
|
return (clip_vision,)
|
||||||
|
|
||||||
@ -956,7 +958,7 @@ class unCLIPConditioning:
|
|||||||
class GLIGENLoader:
|
class GLIGENLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "gligen_name": (folder_paths.get_filename_list("gligen"),)}}
|
return {"required": { "gligen_name": (get_filename_list_with_downloadable("gligen", KNOWN_GLIGEN_MODELS),)}}
|
||||||
|
|
||||||
RETURN_TYPES = ("GLIGEN",)
|
RETURN_TYPES = ("GLIGEN",)
|
||||||
FUNCTION = "load_gligen"
|
FUNCTION = "load_gligen"
|
||||||
@ -964,7 +966,7 @@ class GLIGENLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_gligen(self, gligen_name):
|
def load_gligen(self, gligen_name):
|
||||||
gligen_path = folder_paths.get_full_path("gligen", gligen_name)
|
gligen_path = get_or_download("gligen", gligen_name, KNOWN_GLIGEN_MODELS)
|
||||||
gligen = sd.load_gligen(gligen_path)
|
gligen = sd.load_gligen(gligen_path)
|
||||||
return (gligen,)
|
return (gligen,)
|
||||||
|
|
||||||
|
|||||||
@ -7,6 +7,9 @@ from . import checkpoint_pickle
|
|||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from tqdm import tqdm
|
||||||
|
from contextlib import contextmanager
|
||||||
|
|
||||||
|
|
||||||
def load_torch_file(ckpt, safe_load=False, device=None):
|
def load_torch_file(ckpt, safe_load=False, device=None):
|
||||||
if device is None:
|
if device is None:
|
||||||
@ -486,3 +489,30 @@ class ProgressBar:
|
|||||||
|
|
||||||
def get_project_root() -> str:
|
def get_project_root() -> str:
|
||||||
return os.path.join(os.path.dirname(__file__), "..")
|
return os.path.join(os.path.dirname(__file__), "..")
|
||||||
|
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def comfy_tqdm():
|
||||||
|
"""
|
||||||
|
Monky patches child calls to tqdm and sends the progress to the UI
|
||||||
|
:return:
|
||||||
|
"""
|
||||||
|
_original_init = tqdm.__init__
|
||||||
|
_original_update = tqdm.update
|
||||||
|
try:
|
||||||
|
def __init(self, *args, **kwargs):
|
||||||
|
_original_init(self, *args, **kwargs)
|
||||||
|
self._progress_bar = ProgressBar(self.total)
|
||||||
|
|
||||||
|
def __update(self, n=1):
|
||||||
|
assert self._progress_bar is not None
|
||||||
|
_original_update(self, n)
|
||||||
|
self._progress_bar.update(n)
|
||||||
|
|
||||||
|
tqdm.__init__ = __init
|
||||||
|
tqdm.update = __update
|
||||||
|
yield
|
||||||
|
finally:
|
||||||
|
# Restore original tqdm
|
||||||
|
tqdm.__init__ = _original_init
|
||||||
|
tqdm.update = _original_update
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_UPSCALERS, get_or_download
|
||||||
from ..chainner_models import model_loading
|
from ..chainner_models import model_loading
|
||||||
from comfy import model_management
|
from comfy import model_management
|
||||||
import torch
|
import torch
|
||||||
@ -8,7 +9,7 @@ from comfy.cmd import folder_paths
|
|||||||
class UpscaleModelLoader:
|
class UpscaleModelLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"model_name": (folder_paths.get_filename_list("upscale_models"),),
|
return {"required": {"model_name": (get_filename_list_with_downloadable("upscale_models", KNOWN_UPSCALERS),),
|
||||||
}}
|
}}
|
||||||
|
|
||||||
RETURN_TYPES = ("UPSCALE_MODEL",)
|
RETURN_TYPES = ("UPSCALE_MODEL",)
|
||||||
@ -17,7 +18,7 @@ class UpscaleModelLoader:
|
|||||||
CATEGORY = "loaders"
|
CATEGORY = "loaders"
|
||||||
|
|
||||||
def load_model(self, model_name):
|
def load_model(self, model_name):
|
||||||
model_path = folder_paths.get_full_path("upscale_models", model_name)
|
model_path = get_or_download("upscale_models", model_name, KNOWN_UPSCALERS)
|
||||||
sd = utils.load_torch_file(model_path, safe_load=True)
|
sd = utils.load_torch_file(model_path, safe_load=True)
|
||||||
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
if "module.layers.0.residual_group.blocks.0.norm1.weight" in sd:
|
||||||
sd = utils.state_dict_prefix_replace(sd, {"module.": ""})
|
sd = utils.state_dict_prefix_replace(sd, {"module.": ""})
|
||||||
|
|||||||
@ -1,3 +1,4 @@
|
|||||||
|
from comfy.model_downloader import get_filename_list_with_downloadable, KNOWN_IMAGE_ONLY_CHECKPOINTS, get_or_download
|
||||||
from comfy.nodes.common import MAX_RESOLUTION
|
from comfy.nodes.common import MAX_RESOLUTION
|
||||||
import torch
|
import torch
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
@ -9,7 +10,7 @@ from . import nodes_model_merging
|
|||||||
class ImageOnlyCheckpointLoader:
|
class ImageOnlyCheckpointLoader:
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "ckpt_name": (folder_paths.get_filename_list("checkpoints"), ),
|
return {"required": { "ckpt_name": (get_filename_list_with_downloadable("checkpoints", KNOWN_IMAGE_ONLY_CHECKPOINTS), ),
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
RETURN_TYPES = ("MODEL", "CLIP_VISION", "VAE")
|
||||||
FUNCTION = "load_checkpoint"
|
FUNCTION = "load_checkpoint"
|
||||||
@ -17,7 +18,7 @@ class ImageOnlyCheckpointLoader:
|
|||||||
CATEGORY = "loaders/video_models"
|
CATEGORY = "loaders/video_models"
|
||||||
|
|
||||||
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
def load_checkpoint(self, ckpt_name, output_vae=True, output_clip=True):
|
||||||
ckpt_path = folder_paths.get_full_path("checkpoints", ckpt_name)
|
ckpt_path = get_or_download("checkpoints", ckpt_name, KNOWN_IMAGE_ONLY_CHECKPOINTS)
|
||||||
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
out = comfy.sd.load_checkpoint_guess_config(ckpt_path, output_vae=True, output_clip=False, output_clipvision=True, embedding_directory=folder_paths.get_folder_paths("embeddings"))
|
||||||
return (out[0], out[3], out[2])
|
return (out[0], out[3], out[2])
|
||||||
|
|
||||||
|
|||||||
@ -30,4 +30,5 @@ ConfigArgParse
|
|||||||
aio-pika
|
aio-pika
|
||||||
pyjwt[crypto]
|
pyjwt[crypto]
|
||||||
kornia>=0.7.1
|
kornia>=0.7.1
|
||||||
mpmath>=1.0,!=1.4.0a0
|
mpmath>=1.0,!=1.4.0a0
|
||||||
|
huggingface_hub
|
||||||
Loading…
Reference in New Issue
Block a user