Fix string joining node, improve model downloading

This commit is contained in:
doctorpangloss 2024-04-04 23:40:29 -07:00
parent abb952ad77
commit 3e002b9f72
3 changed files with 59 additions and 13 deletions

View File

@ -6,14 +6,17 @@ from itertools import chain
from os.path import join
from typing import List, Any, Optional, Union
import tqdm
from huggingface_hub import hf_hub_download
from requests import Session
from safetensors import safe_open
from safetensors.torch import save_file
from .cmd import folder_paths
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
from .interruption import InterruptProcessingException
from .utils import ProgressBar, comfy_tqdm
from .cli_args import args
from .cmd import folder_paths
from .interruption import InterruptProcessingException
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
from .utils import ProgressBar, comfy_tqdm
_session = Session()
@ -30,7 +33,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
if path is None and not args.disable_known_models:
try:
# todo: should this be the first or last path?
destination = folder_paths.get_folder_paths(folder_name)[0]
this_model_directory = folder_paths.get_folder_paths(folder_name)[0]
known_file: Optional[HuggingFile | CivitFile] = None
for candidate in known_files:
if str(candidate) == filename or candidate.filename == filename or filename in candidate.alternate_filenames or filename == candidate.save_with_filename:
@ -40,19 +43,57 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
return path
with comfy_tqdm():
if isinstance(known_file, HuggingFile):
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
local_dir=destination,
resume_download=True)
if known_file.save_with_filename is not None:
linked_filename = known_file.save_with_filename
elif os.path.basename(known_file.filename) != known_file.filename:
linked_filename = os.path.basename(known_file.filename)
else:
linked_filename = None
if linked_filename is not None and os.path.dirname(known_file.filename) == "":
# if the known file has an overridden linked name, save it into a repo_id sub directory
# this deals with situations like
# jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors
# it should be saved to controlnet-densepose-sdxl.safetensors
# since there are a bajillion diffusion_pytorch_model.safetensors, it should be downloaded by hf into jschoormans/controlnet-densepose-sdxl/diffusion_pytorch_model.safetensors
# then linked to the local folder to controlnet-densepose-sdxl.safetensors or some other canonical name
hf_destination_dir = os.path.join(this_model_directory, known_file.repo_id)
else:
hf_destination_dir = this_model_directory
# converted 16 bit files should be skipped
path = os.path.join(hf_destination_dir, known_file.filename)
try:
file_size = os.stat(path, follow_symlinks=True).st_size if os.path.isfile(path) else None
except:
file_size = None
if os.path.isfile(path) and file_size == known_file.size:
return path
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
local_dir=hf_destination_dir,
resume_download=True)
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
tensors = {}
with safe_open(path, framework="pt") as f:
with tqdm.tqdm(total=len(f.keys())) as pb:
for k in f.keys():
x = f.get_tensor(k)
tensors[k] = x.half()
del x
pb.update()
save_file(tensors, path)
for _, v in tensors.items():
del v
logging.info(f"Converted {path} to 16 bit, size is {os.stat(path, follow_symlinks=True).st_size}")
try:
if linked_filename is not None:
os.symlink(os.path.join(destination, known_file.filename), linked_filename)
os.symlink(os.path.join(hf_destination_dir, known_file.filename), os.path.join(this_model_directory, linked_filename))
except Exception as exc_info:
logging.error(f"Failed to link file with alternative download save name in a way that is compatible with Hugging Face caching {repr(known_file)}", exc_info=exc_info)
else:
@ -75,7 +116,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
if url is None:
logging.warning(f"Could not retrieve file {str(known_file)}")
else:
destination_with_filename = join(destination, save_filename)
destination_with_filename = join(this_model_directory, save_filename)
try:
with _session.get(url, stream=True, allow_redirects=True) as response:
@ -154,7 +195,7 @@ KNOWN_LORAS = [
]
KNOWN_CONTROLNETS = [
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors"),
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "OpenPoseXL2.safetensors", convert_to_16_bit=True, size=2502139104),
HuggingFile("thibaud/controlnet-openpose-sdxl-1.0", "control-lora-openposeXL2-rank256.safetensors"),
HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_ip2p_fp16.safetensors"),
HuggingFile("comfyanonymous/ControlNet-v1-1_fp16_safetensors", "control_lora_rank128_v11e_sd15_shuffle_fp16.safetensors"),
@ -226,6 +267,7 @@ KNOWN_CONTROLNETS = [
HuggingFile("lllyasviel/sd_control_collection", "t2i-adapter_xl_sketch.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "thibaud_xl_openpose.safetensors"),
HuggingFile("lllyasviel/sd_control_collection", "thibaud_xl_openpose_256lora.safetensors"),
HuggingFile("jschoormans/controlnet-densepose-sdxl", "diffusion_pytorch_model.safetensors", save_with_filename="controlnet-densepose-sdxl.safetensors", convert_to_16_bit=True, size=2502139104),
]
KNOWN_DIFF_CONTROLNETS = [

View File

@ -49,9 +49,11 @@ class HuggingFile:
save_with_filename: Optional[str] = None
alternate_filenames: List[str] = dataclasses.field(default_factory=list)
show_in_ui: Optional[bool] = True
convert_to_16_bit: Optional[bool] = False
size: Optional[int] = None
def __str__(self):
return split(self.filename)[-1]
return self.save_with_filename or split(self.filename)[-1]
class CivitStats(TypedDict):

View File

@ -276,6 +276,7 @@ class StringJoin(CustomNode):
RETURN_TYPES = ("STRING",)
CATEGORY = "api/openapi"
FUNCTION = "execute"
def execute(self, separator: str = "_", *args: str, **kwargs) -> ValidatedNodeResult:
sorted_keys = natsorted(kwargs.keys())
@ -323,6 +324,7 @@ class UriFormat(CustomNode):
}
RETURN_TYPES = ("URIS", "URIS")
RETURN_NAMES = ("URIS (FILES)", "URIS (META)")
FUNCTION = "execute"
CATEGORY = "api/openapi"