Improve model downloader coherence with packages like controlnext-aux

This commit is contained in:
doctorpangloss 2024-04-30 14:28:44 -07:00
parent 0862863bc0
commit b94b90c1cc
4 changed files with 28 additions and 12 deletions

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import dataclasses import dataclasses
import logging
import os import os
import posixpath
import sys import sys
import time import time
import logging
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
from pkg_resources import resource_filename from pkg_resources import resource_filename
@ -39,14 +40,17 @@ class FolderPathsTuple:
class FolderNames: class FolderNames:
def __init__(self): def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict() self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple: def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str): if not isinstance(item, str):
raise RuntimeError("expected folder path") raise RuntimeError("expected folder path")
if item not in self.contents: if item not in self.contents:
self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set()) default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item] return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple): def __setitem__(self, key: str, value: FolderPathsTuple):
@ -74,8 +78,6 @@ class FolderNames:
return self.contents.keys() return self.contents.keys()
folder_names_and_paths = FolderNames()
# todo: this should be initialized elsewhere # todo: this should be initialized elsewhere
if 'main.py' in sys.argv: if 'main.py' in sys.argv:
base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
@ -90,6 +92,7 @@ elif args.cwd is not None:
else: else:
base_path = os.getcwd() base_path = os.getcwd()
models_dir = os.path.join(base_path, "models") models_dir = os.path.join(base_path, "models")
folder_names_and_paths = FolderNames(models_dir)
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions)) folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"}) folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions)) folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))

View File

@ -9,6 +9,7 @@ It will enable command line argument parsing. If this isn't desired, you must au
""" """
import logging import logging
import os import os
import sys
import warnings import warnings
from opentelemetry import trace from opentelemetry import trace
@ -17,7 +18,7 @@ from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter
from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs
from .. import options from .. import options
@ -54,7 +55,15 @@ def _create_tracer():
sampler = ProgressSpanSampler() sampler = ProgressSpanSampler()
provider = TracerProvider(resource=resource, sampler=sampler) provider = TracerProvider(resource=resource, sampler=sampler)
otlp_exporter = OTLPSpanExporter() if args.otel_exporter_otlp_endpoint is not None else ConsoleSpanExporter() is_debugging = hasattr(sys, 'gettrace') and sys.gettrace() is not None
has_endpoint = args.otel_exporter_otlp_endpoint is not None
if has_endpoint:
otlp_exporter = OTLPSpanExporter()
elif is_debugging:
otlp_exporter = ConsoleSpanExporter()
else:
otlp_exporter = SpanExporter()
processor = BatchSpanProcessor(otlp_exporter) processor = BatchSpanProcessor(otlp_exporter)
provider.add_span_processor(processor) provider.add_span_processor(processor)

View File

@ -45,12 +45,12 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
if isinstance(known_file, HuggingFile): if isinstance(known_file, HuggingFile):
if known_file.save_with_filename is not None: if known_file.save_with_filename is not None:
linked_filename = known_file.save_with_filename linked_filename = known_file.save_with_filename
elif os.path.basename(known_file.filename) != known_file.filename: elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
linked_filename = os.path.basename(known_file.filename) linked_filename = os.path.basename(known_file.filename)
else: else:
linked_filename = None linked_filename = None
if linked_filename is not None and os.path.dirname(known_file.filename) == "": if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
# if the known file has an overridden linked name, save it into a repo_id sub directory # if the known file has an overridden linked name, save it into a repo_id sub directory
# this deals with situations like # this deals with situations like
# jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors # jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors
@ -73,6 +73,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
path = hf_hub_download(repo_id=known_file.repo_id, path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename, filename=known_file.filename,
local_dir=hf_destination_dir, local_dir=hf_destination_dir,
repo_type=known_file.repo_type,
resume_download=True) resume_download=True)
if known_file.convert_to_16_bit and file_size is not None and file_size != 0: if known_file.convert_to_16_bit and file_size is not None and file_size != 0:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import dataclasses import dataclasses
from os.path import split from os.path import split
from typing import Optional, List, Sequence from typing import Optional, List, Sequence
from typing_extensions import TypedDict, NotRequired from typing_extensions import TypedDict, NotRequired
@ -51,6 +52,8 @@ class HuggingFile:
show_in_ui: Optional[bool] = True show_in_ui: Optional[bool] = True
convert_to_16_bit: Optional[bool] = False convert_to_16_bit: Optional[bool] = False
size: Optional[int] = None size: Optional[int] = None
force_save_in_repo_id: Optional[bool] = False
repo_type: Optional[str] = 'model'
def __str__(self): def __str__(self):
return self.save_with_filename or split(self.filename)[-1] return self.save_with_filename or split(self.filename)[-1]