From b94b90c1cced016c128268dfb3c25d47be5c1704 Mon Sep 17 00:00:00 2001 From: doctorpangloss <@hiddenswitch.com> Date: Tue, 30 Apr 2024 14:28:44 -0700 Subject: [PATCH] Improve model downloader coherence with packages like controlnext-aux --- comfy/cmd/folder_paths.py | 19 +++++++++++-------- comfy/cmd/main_pre.py | 13 +++++++++++-- comfy/model_downloader.py | 5 +++-- comfy/model_downloader_types.py | 3 +++ 4 files changed, 28 insertions(+), 12 deletions(-) diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 8a10ebd64..5c12f2726 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -1,9 +1,10 @@ +from __future__ import annotations + import dataclasses +import logging import os -import posixpath import sys import time -import logging from typing import Optional, List, Set, Dict, Any, Iterator, Sequence from pkg_resources import resource_filename @@ -39,14 +40,17 @@ class FolderPathsTuple: class FolderNames: - def __init__(self): + def __init__(self, default_new_folder_path: str): self.contents: Dict[str, FolderPathsTuple] = dict() + self.default_new_folder_path = default_new_folder_path def __getitem__(self, item) -> FolderPathsTuple: if not isinstance(item, str): raise RuntimeError("expected folder path") if item not in self.contents: - self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set()) + default_path = os.path.join(self.default_new_folder_path, item) + os.makedirs(default_path, exist_ok=True) + self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set()) return self.contents[item] def __setitem__(self, key: str, value: FolderPathsTuple): @@ -62,8 +66,8 @@ class FolderNames: return len(self.contents) def __iter__(self): - return iter(self.contents) - + return iter(self.contents) + def items(self): return self.contents.items() @@ -74,8 +78,6 @@ class FolderNames: return self.contents.keys() -folder_names_and_paths = FolderNames() - # todo: this should be initialized elsewhere if 'main.py' in sys.argv: base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../..")) @@ -90,6 +92,7 @@ elif args.cwd is not None: else: base_path = os.getcwd() models_dir = os.path.join(base_path, "models") +folder_names_and_paths = FolderNames(models_dir) folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions)) folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"}) folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions)) diff --git a/comfy/cmd/main_pre.py b/comfy/cmd/main_pre.py index c3202765d..c3cdd7c9a 100644 --- a/comfy/cmd/main_pre.py +++ b/comfy/cmd/main_pre.py @@ -9,6 +9,7 @@ It will enable command line argument parsing. If this isn't desired, you must au """ import logging import os +import sys import warnings from opentelemetry import trace @@ -17,7 +18,7 @@ from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor from opentelemetry.sdk.resources import Resource from opentelemetry.sdk.trace import TracerProvider -from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter +from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs from .. import options @@ -54,7 +55,15 @@ def _create_tracer(): sampler = ProgressSpanSampler() provider = TracerProvider(resource=resource, sampler=sampler) - otlp_exporter = OTLPSpanExporter() if args.otel_exporter_otlp_endpoint is not None else ConsoleSpanExporter() + is_debugging = hasattr(sys, 'gettrace') and sys.gettrace() is not None + has_endpoint = args.otel_exporter_otlp_endpoint is not None + + if has_endpoint: + otlp_exporter = OTLPSpanExporter() + elif is_debugging: + otlp_exporter = ConsoleSpanExporter() + else: + otlp_exporter = SpanExporter() processor = BatchSpanProcessor(otlp_exporter) provider.add_span_processor(processor) diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py index 8b76ab920..d2af3d22b 100644 --- a/comfy/model_downloader.py +++ b/comfy/model_downloader.py @@ -45,12 +45,12 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi if isinstance(known_file, HuggingFile): if known_file.save_with_filename is not None: linked_filename = known_file.save_with_filename - elif os.path.basename(known_file.filename) != known_file.filename: + elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename: linked_filename = os.path.basename(known_file.filename) else: linked_filename = None - if linked_filename is not None and os.path.dirname(known_file.filename) == "": + if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "": # if the known file has an overridden linked name, save it into a repo_id sub directory # this deals with situations like # jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors @@ -73,6 +73,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi path = hf_hub_download(repo_id=known_file.repo_id, filename=known_file.filename, local_dir=hf_destination_dir, + repo_type=known_file.repo_type, resume_download=True) if known_file.convert_to_16_bit and file_size is not None and file_size != 0: diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py index a3e18a652..b0ef948b0 100644 --- a/comfy/model_downloader_types.py +++ b/comfy/model_downloader_types.py @@ -3,6 +3,7 @@ from __future__ import annotations import dataclasses from os.path import split from typing import Optional, List, Sequence + from typing_extensions import TypedDict, NotRequired @@ -51,6 +52,8 @@ class HuggingFile: show_in_ui: Optional[bool] = True convert_to_16_bit: Optional[bool] = False size: Optional[int] = None + force_save_in_repo_id: Optional[bool] = False + repo_type: Optional[str] = 'model' def __str__(self): return self.save_with_filename or split(self.filename)[-1]