mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Improve model downloader coherence with packages like controlnext-aux
This commit is contained in:
parent
0862863bc0
commit
b94b90c1cc
@ -1,9 +1,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
import posixpath
|
||||
import sys
|
||||
import time
|
||||
import logging
|
||||
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
||||
|
||||
from pkg_resources import resource_filename
|
||||
@ -39,14 +40,17 @@ class FolderPathsTuple:
|
||||
|
||||
|
||||
class FolderNames:
|
||||
def __init__(self):
|
||||
def __init__(self, default_new_folder_path: str):
|
||||
self.contents: Dict[str, FolderPathsTuple] = dict()
|
||||
self.default_new_folder_path = default_new_folder_path
|
||||
|
||||
def __getitem__(self, item) -> FolderPathsTuple:
|
||||
if not isinstance(item, str):
|
||||
raise RuntimeError("expected folder path")
|
||||
if item not in self.contents:
|
||||
self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set())
|
||||
default_path = os.path.join(self.default_new_folder_path, item)
|
||||
os.makedirs(default_path, exist_ok=True)
|
||||
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
|
||||
return self.contents[item]
|
||||
|
||||
def __setitem__(self, key: str, value: FolderPathsTuple):
|
||||
@ -62,8 +66,8 @@ class FolderNames:
|
||||
return len(self.contents)
|
||||
|
||||
def __iter__(self):
|
||||
return iter(self.contents)
|
||||
|
||||
return iter(self.contents)
|
||||
|
||||
def items(self):
|
||||
return self.contents.items()
|
||||
|
||||
@ -74,8 +78,6 @@ class FolderNames:
|
||||
return self.contents.keys()
|
||||
|
||||
|
||||
folder_names_and_paths = FolderNames()
|
||||
|
||||
# todo: this should be initialized elsewhere
|
||||
if 'main.py' in sys.argv:
|
||||
base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
|
||||
@ -90,6 +92,7 @@ elif args.cwd is not None:
|
||||
else:
|
||||
base_path = os.getcwd()
|
||||
models_dir = os.path.join(base_path, "models")
|
||||
folder_names_and_paths = FolderNames(models_dir)
|
||||
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
||||
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
|
||||
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
||||
|
||||
@ -9,6 +9,7 @@ It will enable command line argument parsing. If this isn't desired, you must au
|
||||
"""
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
|
||||
from opentelemetry import trace
|
||||
@ -17,7 +18,7 @@ from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
|
||||
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
|
||||
from opentelemetry.sdk.resources import Resource
|
||||
from opentelemetry.sdk.trace import TracerProvider
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
|
||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter
|
||||
from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs
|
||||
|
||||
from .. import options
|
||||
@ -54,7 +55,15 @@ def _create_tracer():
|
||||
sampler = ProgressSpanSampler()
|
||||
provider = TracerProvider(resource=resource, sampler=sampler)
|
||||
|
||||
otlp_exporter = OTLPSpanExporter() if args.otel_exporter_otlp_endpoint is not None else ConsoleSpanExporter()
|
||||
is_debugging = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||
has_endpoint = args.otel_exporter_otlp_endpoint is not None
|
||||
|
||||
if has_endpoint:
|
||||
otlp_exporter = OTLPSpanExporter()
|
||||
elif is_debugging:
|
||||
otlp_exporter = ConsoleSpanExporter()
|
||||
else:
|
||||
otlp_exporter = SpanExporter()
|
||||
|
||||
processor = BatchSpanProcessor(otlp_exporter)
|
||||
provider.add_span_processor(processor)
|
||||
|
||||
@ -45,12 +45,12 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
||||
if isinstance(known_file, HuggingFile):
|
||||
if known_file.save_with_filename is not None:
|
||||
linked_filename = known_file.save_with_filename
|
||||
elif os.path.basename(known_file.filename) != known_file.filename:
|
||||
elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
|
||||
linked_filename = os.path.basename(known_file.filename)
|
||||
else:
|
||||
linked_filename = None
|
||||
|
||||
if linked_filename is not None and os.path.dirname(known_file.filename) == "":
|
||||
if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
|
||||
# if the known file has an overridden linked name, save it into a repo_id sub directory
|
||||
# this deals with situations like
|
||||
# jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors
|
||||
@ -73,6 +73,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
||||
path = hf_hub_download(repo_id=known_file.repo_id,
|
||||
filename=known_file.filename,
|
||||
local_dir=hf_destination_dir,
|
||||
repo_type=known_file.repo_type,
|
||||
resume_download=True)
|
||||
|
||||
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
|
||||
|
||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
||||
import dataclasses
|
||||
from os.path import split
|
||||
from typing import Optional, List, Sequence
|
||||
|
||||
from typing_extensions import TypedDict, NotRequired
|
||||
|
||||
|
||||
@ -51,6 +52,8 @@ class HuggingFile:
|
||||
show_in_ui: Optional[bool] = True
|
||||
convert_to_16_bit: Optional[bool] = False
|
||||
size: Optional[int] = None
|
||||
force_save_in_repo_id: Optional[bool] = False
|
||||
repo_type: Optional[str] = 'model'
|
||||
|
||||
def __str__(self):
|
||||
return self.save_with_filename or split(self.filename)[-1]
|
||||
|
||||
Loading…
Reference in New Issue
Block a user