From 3e002b9f72a814158136f637c6ece426ef95a236 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Thu, 4 Apr 2024 23:40:29 -0700 Subject: [PATCH] Fix string joining node, improve model downloading --- comfy/model_downloader.py | 66 +++++++++++++++++++++++----- comfy/model_downloader_types.py | 4 +- comfy_extras/nodes/nodes_open_api.py | 2 + 3 files changed, 59 insertions(+), 13 deletions(-) diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index fb77430a4..23e2d29a7 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -6,14 +6,17 @@ from itertools import chain from os.path import join from typing import List, Any, Optional, Union +import tqdm from huggingface_hub import hf_hub_download from requests import Session +from safetensors import safe_open +from safetensors.torch import save_file -from .cmd import folder_paths -from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_ -from .interruption import InterruptProcessingException -from .utils import ProgressBar, comfy_tqdm from .cli_args import args +from .cmd import folder_paths +from .interruption import InterruptProcessingException +from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_ +from .utils import ProgressBar, comfy_tqdm _session = Session() @@ -30,7 +33,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi if path is None and not args.disable_known_models: try: # todo: should this be the first or last path? - destination = folder_paths.get_folder_paths(folder_name)[0] + this_model_directory = folder_paths.get_folder_paths(folder_name)[0] known_file: Optional[HuggingFile | CivitFile] = None for candidate in known_files: if str(candidate) == filename or candidate.filename == filename or filename in candidate.alternate_filenames or filename == candidate.save_with_filename: @@ -40,19 +43,57 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi return path with comfy_tqdm(): if isinstance(known_file, HuggingFile): - path = hf_hub_download(repo_id=known_file.repo_id, - filename=known_file.filename, - local_dir=destination, - resume_download=True) if known_file.save_with_filename is not None: linked_filename = known_file.save_with_filename elif os.path.basename(known_file.filename) != known_file.filename: linked_filename = os.path.basename(known_file.filename) else: linked_filename = None + + if linked_filename is not None and os.path.dirname(known_file.filename) == "": + # if the known file has an overridden linked name, save it into a repo_id sub directory + # this deals with situations like + # jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors + # it should be saved to controlnet-densepose-sdxl.safetensors + # since there are a bajillion diffusion_pytorch_model.safetensors, it should be downloaded by hf into jschoormans/controlnet-densepose-sdxl/diffusion_pytorch_model.safetensors + # then linked to the local folder to controlnet-densepose-sdxl.safetensors or some other canonical name + hf_destination_dir = os.path.join(this_model_directory, known_file.repo_id) + else: + hf_destination_dir = this_model_directory + + # converted 16 bit files should be skipped + path = os.path.join(hf_destination_dir, known_file.filename) + try: + file_size = os.stat(path, follow_symlinks=True).st_size if os.path.isfile(path) else None + except: + file_size = None + if os.path.isfile(path) and file_size == known_file.size: + return path + + path = hf_hub_download(repo_id=known_file.repo_id, + filename=known_file.filename, + local_dir=hf_destination_dir, + resume_download=True) + + if known_file.convert_to_16_bit and file_size is not None and file_size != 0: + tensors = {} + with safe_open(path, framework="pt") as f: + with tqdm.tqdm(total=len(f.keys())) as pb: + for k in f.keys(): + x = f.get_tensor(k) + tensors[k] = x.half() + del x + pb.update() + + save_file(tensors, path) + + for _, v in tensors.items(): + del v + logging.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}") + try: if linked_filename is not None: - os.symlink(os.path.join(destination, known_file.filename), linked_filename) + os.symlink(os.path.join(hf_destination_dir, known_file.filename), os.path.join(this_model_directory, linked_filename)) except Exception as exc_info: logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info) else: @@ -75,7 +116,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi if url is None: logging.warning(f"Could not retrieve file {str(known_file)}") else: - destination_with_filename = join(destination, save_filename) + destination_with_filename = join(this_model_directory, save_filename) try: with _session.get(url, stream=True, allow_redirects=True) as response: @@ -154,7 +195,7 @@ KNOWN_LORAS = [ ] KNOWN_CONTROLNETS = [ - HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors"), + HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104), HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"), HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"), HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_shuffle_fp16.safetensors"), @@ -226,6 +267,7 @@ KNOWN_CONTROLNETS = [ HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_xl_sketch.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "thibaud_xl_openpose.safetensors"), HuggingFile("lllyasviel/sd_control_collection", "thibaud_xl_openpose_256lora.safetensors"), + HuggingFile("jschoormans/controlnet-densepose-sdxl", "diffusion_pytorch_model.safetensors", save_with_filename="controlnet-densepose-sdxl.safetensors", convert_to_16_bit=True, size=2502139104), ] KNOWN_DIFF_CONTROLNETS = [ diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index 554a325e5..a3e18a652 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -49,9 +49,11 @@ class HuggingFile: save_with_filename: Optional[str] = None alternate_filenames: List[str] = dataclasses.field(default_factory=list) show_in_ui: Optional[bool] = True + convert_to_16_bit: Optional[bool] = False + size: Optional[int] = None def __str__(self): - return split(self.filename)[-1] + return self.save_with_filename or split(self.filename)[-1] class CivitStats(TypedDict): diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py index b5bb8e63e..c9ee728da 100644 --- a/comfy_extras/nodes/nodes_open_api.py +++ b/comfy_extras/nodes/nodes_open_api.py @@ -276,6 +276,7 @@ class StringJoin(CustomNode): RETURN_TYPES = ("STRING",) CATEGORY = "api/openapi" + FUNCTION = "execute" def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult: sorted_keys = natsorted(kwargs.keys()) @@ -323,6 +324,7 @@ class UriFormat(CustomNode): } RETURN_TYPES = ("URIS", "URIS") + RETURN_NAMES = ("URIS (FILES)", "URIS (META)") FUNCTION = "execute" CATEGORY = "api/openapi"