mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-28 15:20:25 +08:00
Improve model downloader coherence with packages like controlnext-aux
This commit is contained in:
parent
0862863bc0
commit
b94b90c1cc
@ -1,9 +1,10 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import dataclasses
|
import dataclasses
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import posixpath
|
|
||||||
import sys
|
import sys
|
||||||
import time
|
import time
|
||||||
import logging
|
|
||||||
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
from typing import Optional, List, Set, Dict, Any, Iterator, Sequence
|
||||||
|
|
||||||
from pkg_resources import resource_filename
|
from pkg_resources import resource_filename
|
||||||
@ -39,14 +40,17 @@ class FolderPathsTuple:
|
|||||||
|
|
||||||
|
|
||||||
class FolderNames:
|
class FolderNames:
|
||||||
def __init__(self):
|
def __init__(self, default_new_folder_path: str):
|
||||||
self.contents: Dict[str, FolderPathsTuple] = dict()
|
self.contents: Dict[str, FolderPathsTuple] = dict()
|
||||||
|
self.default_new_folder_path = default_new_folder_path
|
||||||
|
|
||||||
def __getitem__(self, item) -> FolderPathsTuple:
|
def __getitem__(self, item) -> FolderPathsTuple:
|
||||||
if not isinstance(item, str):
|
if not isinstance(item, str):
|
||||||
raise RuntimeError("expected folder path")
|
raise RuntimeError("expected folder path")
|
||||||
if item not in self.contents:
|
if item not in self.contents:
|
||||||
self.contents[item] = FolderPathsTuple(item, paths=[], supported_extensions=set())
|
default_path = os.path.join(self.default_new_folder_path, item)
|
||||||
|
os.makedirs(default_path, exist_ok=True)
|
||||||
|
self.contents[item] = FolderPathsTuple(item, paths=[default_path], supported_extensions=set())
|
||||||
return self.contents[item]
|
return self.contents[item]
|
||||||
|
|
||||||
def __setitem__(self, key: str, value: FolderPathsTuple):
|
def __setitem__(self, key: str, value: FolderPathsTuple):
|
||||||
@ -62,8 +66,8 @@ class FolderNames:
|
|||||||
return len(self.contents)
|
return len(self.contents)
|
||||||
|
|
||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.contents)
|
return iter(self.contents)
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self.contents.items()
|
return self.contents.items()
|
||||||
|
|
||||||
@ -74,8 +78,6 @@ class FolderNames:
|
|||||||
return self.contents.keys()
|
return self.contents.keys()
|
||||||
|
|
||||||
|
|
||||||
folder_names_and_paths = FolderNames()
|
|
||||||
|
|
||||||
# todo: this should be initialized elsewhere
|
# todo: this should be initialized elsewhere
|
||||||
if 'main.py' in sys.argv:
|
if 'main.py' in sys.argv:
|
||||||
base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
|
base_path = os.path.realpath(os.path.join(os.path.dirname(os.path.realpath(__file__)), "../.."))
|
||||||
@ -90,6 +92,7 @@ elif args.cwd is not None:
|
|||||||
else:
|
else:
|
||||||
base_path = os.getcwd()
|
base_path = os.getcwd()
|
||||||
models_dir = os.path.join(base_path, "models")
|
models_dir = os.path.join(base_path, "models")
|
||||||
|
folder_names_and_paths = FolderNames(models_dir)
|
||||||
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
folder_names_and_paths["checkpoints"] = FolderPathsTuple("checkpoints", [os.path.join(models_dir, "checkpoints")], set(supported_pt_extensions))
|
||||||
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
|
folder_names_and_paths["configs"] = FolderPathsTuple("configs", [os.path.join(models_dir, "configs"), resource_filename("comfy", "configs/")], {".yaml"})
|
||||||
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
folder_names_and_paths["loras"] = FolderPathsTuple("loras", [os.path.join(models_dir, "loras")], set(supported_pt_extensions))
|
||||||
|
|||||||
@ -9,6 +9,7 @@ It will enable command line argument parsing. If this isn't desired, you must au
|
|||||||
"""
|
"""
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
import sys
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
from opentelemetry import trace
|
from opentelemetry import trace
|
||||||
@ -17,7 +18,7 @@ from opentelemetry.instrumentation.aio_pika import AioPikaInstrumentor
|
|||||||
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
|
from opentelemetry.instrumentation.aiohttp_server import AioHttpServerInstrumentor
|
||||||
from opentelemetry.sdk.resources import Resource
|
from opentelemetry.sdk.resources import Resource
|
||||||
from opentelemetry.sdk.trace import TracerProvider
|
from opentelemetry.sdk.trace import TracerProvider
|
||||||
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter
|
from opentelemetry.sdk.trace.export import BatchSpanProcessor, ConsoleSpanExporter, SpanExporter
|
||||||
from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs
|
from opentelemetry.semconv.resource import ResourceAttributes as ResAttrs
|
||||||
|
|
||||||
from .. import options
|
from .. import options
|
||||||
@ -54,7 +55,15 @@ def _create_tracer():
|
|||||||
sampler = ProgressSpanSampler()
|
sampler = ProgressSpanSampler()
|
||||||
provider = TracerProvider(resource=resource, sampler=sampler)
|
provider = TracerProvider(resource=resource, sampler=sampler)
|
||||||
|
|
||||||
otlp_exporter = OTLPSpanExporter() if args.otel_exporter_otlp_endpoint is not None else ConsoleSpanExporter()
|
is_debugging = hasattr(sys, 'gettrace') and sys.gettrace() is not None
|
||||||
|
has_endpoint = args.otel_exporter_otlp_endpoint is not None
|
||||||
|
|
||||||
|
if has_endpoint:
|
||||||
|
otlp_exporter = OTLPSpanExporter()
|
||||||
|
elif is_debugging:
|
||||||
|
otlp_exporter = ConsoleSpanExporter()
|
||||||
|
else:
|
||||||
|
otlp_exporter = SpanExporter()
|
||||||
|
|
||||||
processor = BatchSpanProcessor(otlp_exporter)
|
processor = BatchSpanProcessor(otlp_exporter)
|
||||||
provider.add_span_processor(processor)
|
provider.add_span_processor(processor)
|
||||||
|
|||||||
@ -45,12 +45,12 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
|||||||
if isinstance(known_file, HuggingFile):
|
if isinstance(known_file, HuggingFile):
|
||||||
if known_file.save_with_filename is not None:
|
if known_file.save_with_filename is not None:
|
||||||
linked_filename = known_file.save_with_filename
|
linked_filename = known_file.save_with_filename
|
||||||
elif os.path.basename(known_file.filename) != known_file.filename:
|
elif not known_file.force_save_in_repo_id and os.path.basename(known_file.filename) != known_file.filename:
|
||||||
linked_filename = os.path.basename(known_file.filename)
|
linked_filename = os.path.basename(known_file.filename)
|
||||||
else:
|
else:
|
||||||
linked_filename = None
|
linked_filename = None
|
||||||
|
|
||||||
if linked_filename is not None and os.path.dirname(known_file.filename) == "":
|
if known_file.force_save_in_repo_id or linked_filename is not None and os.path.dirname(known_file.filename) == "":
|
||||||
# if the known file has an overridden linked name, save it into a repo_id sub directory
|
# if the known file has an overridden linked name, save it into a repo_id sub directory
|
||||||
# this deals with situations like
|
# this deals with situations like
|
||||||
# jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors
|
# jschoormans/controlnet-densepose-sdxl repo having diffusion_pytorch_model.safetensors
|
||||||
@ -73,6 +73,7 @@ def get_or_download(folder_name: str, filename: str, known_files: List[HuggingFi
|
|||||||
path = hf_hub_download(repo_id=known_file.repo_id,
|
path = hf_hub_download(repo_id=known_file.repo_id,
|
||||||
filename=known_file.filename,
|
filename=known_file.filename,
|
||||||
local_dir=hf_destination_dir,
|
local_dir=hf_destination_dir,
|
||||||
|
repo_type=known_file.repo_type,
|
||||||
resume_download=True)
|
resume_download=True)
|
||||||
|
|
||||||
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
|
if known_file.convert_to_16_bit and file_size is not None and file_size != 0:
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from __future__ import annotations
|
|||||||
import dataclasses
|
import dataclasses
|
||||||
from os.path import split
|
from os.path import split
|
||||||
from typing import Optional, List, Sequence
|
from typing import Optional, List, Sequence
|
||||||
|
|
||||||
from typing_extensions import TypedDict, NotRequired
|
from typing_extensions import TypedDict, NotRequired
|
||||||
|
|
||||||
|
|
||||||
@ -51,6 +52,8 @@ class HuggingFile:
|
|||||||
show_in_ui: Optional[bool] = True
|
show_in_ui: Optional[bool] = True
|
||||||
convert_to_16_bit: Optional[bool] = False
|
convert_to_16_bit: Optional[bool] = False
|
||||||
size: Optional[int] = None
|
size: Optional[int] = None
|
||||||
|
force_save_in_repo_id: Optional[bool] = False
|
||||||
|
repo_type: Optional[str] = 'model'
|
||||||
|
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return self.save_with_filename or split(self.filename)[-1]
|
return self.save_with_filename or split(self.filename)[-1]
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user