Improve model downloader coherence with packages like controlnext-aux

This commit is contained in:
doctorpangloss 2024-04-30 14:28:44 -07:00
parent 0862863bc0
commit b94b90c1cc
4 changed files with 28 additions and 12 deletions

View File

@ -1,9 +1,10 @@
from __future__ import annotations
import dataclasses
import logging
import os
import posixpath
import sys
import time
import logging
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
from pkg_resources import resource_filename
@ -39,14 +40,17 @@ class FolderPathsTuple:
class FolderNames:
def __init__(self):
def __init__(self, default_new_folder_path: str):
self.contents: Dict[str, FolderPathsTuple] = dict()
self.default_new_folder_path = default_new_folder_path
def __getitem__(self, item) -> FolderPathsTuple:
if not isinstance(item, str):
raise RuntimeError("expected folder path")
if item not in self.contents:
self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set())
default_path = os.path.join(self.default_new_folder_path, item)
os.makedirs(default_path, exist_ok=True)
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
return self.contents[item]
def __setitem__(self, key: str, value: FolderPathsTuple):
@ -62,8 +66,8 @@ class FolderNames:
return len(self.contents)
def __iter__(self):
return iter(self.contents)
return iter(self.contents)
def items(self):
return self.contents.items()
@ -74,8 +78,6 @@ class FolderNames:
return self.contents.keys()
folder_names_and_paths = FolderNames()
# todo: this should be initialized elsewhere
if 'main.py' in sys.argv:
base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
@ -90,6 +92,7 @@ elif args.cwd is not None:
else:
base_path = os.getcwd()
models_dir = os.path.join(base_path, "models")
folder_names_and_paths = FolderNames(models_dir)
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))

View File

@ -9,6 +9,7 @@ It will enable command line argument parsing. If this isn't desired, you must au
"""
import logging
import os
import sys
import warnings
from opentelemetry import trace
@ -17,7 +18,7 @@ from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter
from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs
from .. import options
@ -54,7 +55,15 @@ def _create_tracer():
sampler = ProgressSpanSampler()
provider = TracerProvider(resource=resource, sampler=sampler)
otlp_exporter = OTLPSpanExporter() if args.otel_exporter_otlp_endpoint is not None else ConsoleSpanExporter()
is_debugging = hasattr(sys, 'gettrace') and sys.gettrace() is not None
has_endpoint = args.otel_exporter_otlp_endpoint is not None
if has_endpoint:
otlp_exporter = OTLPSpanExporter()
elif is_debugging:
otlp_exporter = ConsoleSpanExporter()
else:
otlp_exporter = SpanExporter()
processor = BatchSpanProcessor(otlp_exporter)
provider.add_span_processor(processor)

View File

@ -45,12 +45,12 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
if isinstance(known_file, HuggingFile):
if known_file.save_with_filename is not None:
linked_filename = known_file.save_with_filename
elif os.path.basename(known_file.filename) != known_file.filename:
elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
linked_filename = os.path.basename(known_file.filename)
else:
linked_filename = None
if linked_filename is not None and os.path.dirname(known_file.filename) == "":
if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
# if the known file has an overridden linked name, save it into a repo_id sub directory
# this deals with situations like
# jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors
@ -73,6 +73,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
path = hf_hub_download(repo_id=known_file.repo_id,
filename=known_file.filename,
local_dir=hf_destination_dir,
repo_type=known_file.repo_type,
resume_download=True)
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:

View File

@ -3,6 +3,7 @@ from __future__ import annotations
import dataclasses
from os.path import split
from typing import Optional, List, Sequence
from typing_extensions import TypedDict, NotRequired
@ -51,6 +52,8 @@ class HuggingFile:
show_in_ui: Optional[bool] = True
convert_to_16_bit: Optional[bool] = False
size: Optional[int] = None
force_save_in_repo_id: Optional[bool] = False
repo_type: Optional[str] = 'model'
def __str__(self):
return self.save_with_filename or split(self.filename)[-1]