mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 23:00:51 +08:00
Improve model downloading from Hugging Face Hub
This commit is contained in:
parent
da21da1d8c
commit
3d67224937
@ -158,6 +158,7 @@ def _create_parser() -> EnhancedConfigArgParser:
|
|||||||
parser.add_argument("--otel-service-version", type=str, default=__version__, env_var="OTEL_SERVICE_VERSION", help="The version of the service or application that is generating telemetry data.")
|
parser.add_argument("--otel-service-version", type=str, default=__version__, env_var="OTEL_SERVICE_VERSION", help="The version of the service or application that is generating telemetry data.")
|
||||||
parser.add_argument("--otel-exporter-otlp-endpoint", type=str, default=None, env_var="OTEL_EXPORTER_OTLP_ENDPOINT", help="A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.")
|
parser.add_argument("--otel-exporter-otlp-endpoint", type=str, default=None, env_var="OTEL_EXPORTER_OTLP_ENDPOINT", help="A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.")
|
||||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||||
|
parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.")
|
||||||
|
|
||||||
# now give plugins a chance to add configuration
|
# now give plugins a chance to add configuration
|
||||||
for entry_point in entry_points().select(group='comfyui.custom_config'):
|
for entry_point in entry_points().select(group='comfyui.custom_config'):
|
||||||
|
|||||||
@ -107,6 +107,7 @@ class Configuration(dict):
|
|||||||
otel_service_version (str): The version of the service or application that is generating telemetry data. Default: "0.0.1".
|
otel_service_version (str): The version of the service or application that is generating telemetry data. Default: "0.0.1".
|
||||||
otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
|
otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
|
||||||
force_channels_last (bool): Force channels last format when inferencing the models.
|
force_channels_last (bool): Force channels last format when inferencing the models.
|
||||||
|
force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, **kwargs):
|
def __init__(self, **kwargs):
|
||||||
@ -178,6 +179,7 @@ class Configuration(dict):
|
|||||||
self.disable_known_models: bool = False
|
self.disable_known_models: bool = False
|
||||||
self.max_queue_size: int = 65536
|
self.max_queue_size: int = 65536
|
||||||
self.force_channels_last: bool = False
|
self.force_channels_last: bool = False
|
||||||
|
self.force_hf_local_dir_mode = False
|
||||||
|
|
||||||
# from opentracing docs
|
# from opentracing docs
|
||||||
self.otel_service_name: str = "comfyui"
|
self.otel_service_name: str = "comfyui"
|
||||||
|
|||||||
@ -57,8 +57,6 @@ class FolderNames:
|
|||||||
if isinstance(value, tuple):
|
if isinstance(value, tuple):
|
||||||
paths, supported_extensions = value
|
paths, supported_extensions = value
|
||||||
value = FolderPathsTuple(key, paths, supported_extensions)
|
value = FolderPathsTuple(key, paths, supported_extensions)
|
||||||
if key in self.contents:
|
|
||||||
value = self.contents[key] + value
|
|
||||||
self.contents[key] = value
|
self.contents[key] = value
|
||||||
|
|
||||||
def __len__(self):
|
def __len__(self):
|
||||||
@ -67,6 +65,9 @@ class FolderNames:
|
|||||||
def __iter__(self):
|
def __iter__(self):
|
||||||
return iter(self.contents)
|
return iter(self.contents)
|
||||||
|
|
||||||
|
def __delitem__(self, key):
|
||||||
|
del self.contents[key]
|
||||||
|
|
||||||
def items(self):
|
def items(self):
|
||||||
return self.contents.items()
|
return self.contents.items()
|
||||||
|
|
||||||
@ -110,6 +111,8 @@ folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.pa
|
|||||||
folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
|
folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
|
||||||
folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
|
folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
|
||||||
folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
|
folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
|
||||||
|
folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""})
|
||||||
|
folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""})
|
||||||
|
|
||||||
output_directory = os.path.join(base_path, "output")
|
output_directory = os.path.join(base_path, "output")
|
||||||
temp_directory = os.path.join(base_path, "temp")
|
temp_directory = os.path.join(base_path, "temp")
|
||||||
|
|||||||
36
comfy/component_model/deprecation.py
Normal file
36
comfy/component_model/deprecation.py
Normal file
@ -0,0 +1,36 @@
|
|||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import warnings
|
||||||
|
from functools import wraps
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def _deprecate_method(*, version: str, message: Optional[str] = None):
|
||||||
|
"""Decorator to issue warnings when using a deprecated method.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
version (`str`):
|
||||||
|
The version when deprecated arguments will result in error.
|
||||||
|
message (`str`, *optional*):
|
||||||
|
Warning message that is raised. If not passed, a default warning message
|
||||||
|
will be created.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def _inner_deprecate_method(f):
|
||||||
|
name = f.__name__
|
||||||
|
if name == "__init__":
|
||||||
|
name = f.__qualname__.split(".")[0] # class name instead of method name
|
||||||
|
|
||||||
|
@wraps(f)
|
||||||
|
def inner_f(*args, **kwargs):
|
||||||
|
warning_message = (
|
||||||
|
f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'."
|
||||||
|
)
|
||||||
|
if message is not None:
|
||||||
|
warning_message += " " + message
|
||||||
|
warnings.warn(warning_message, FutureWarning)
|
||||||
|
return f(*args, **kwargs)
|
||||||
|
|
||||||
|
return inner_f
|
||||||
|
|
||||||
|
return _inner_deprecate_method
|
||||||
@ -7,6 +7,7 @@ from typing import Optional
|
|||||||
|
|
||||||
from aio_pika import connect_robust
|
from aio_pika import connect_robust
|
||||||
from aio_pika.patterns import JsonRPC
|
from aio_pika.patterns import JsonRPC
|
||||||
|
from aiohttp import web
|
||||||
from aiormq import AMQPConnectionError
|
from aiormq import AMQPConnectionError
|
||||||
|
|
||||||
from .distributed_progress import DistributedExecutorToClientProgress
|
from .distributed_progress import DistributedExecutorToClientProgress
|
||||||
@ -24,6 +25,7 @@ class DistributedPromptWorker:
|
|||||||
def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
|
def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
|
||||||
connection_uri: str = "amqp://localhost:5672/",
|
connection_uri: str = "amqp://localhost:5672/",
|
||||||
queue_name: str = "comfyui",
|
queue_name: str = "comfyui",
|
||||||
|
health_check_port: int = 9090,
|
||||||
loop: Optional[AbstractEventLoop] = None):
|
loop: Optional[AbstractEventLoop] = None):
|
||||||
self._rpc = None
|
self._rpc = None
|
||||||
self._channel = None
|
self._channel = None
|
||||||
@ -32,6 +34,29 @@ class DistributedPromptWorker:
|
|||||||
self._connection_uri = connection_uri
|
self._connection_uri = connection_uri
|
||||||
self._loop = loop or asyncio.get_event_loop()
|
self._loop = loop or asyncio.get_event_loop()
|
||||||
self._embedded_comfy_client = embedded_comfy_client
|
self._embedded_comfy_client = embedded_comfy_client
|
||||||
|
self._health_check_port = health_check_port
|
||||||
|
self._health_check_site = None
|
||||||
|
|
||||||
|
async def _health_check(self, request):
|
||||||
|
return web.Response(text="OK", content_type="text/plain")
|
||||||
|
|
||||||
|
async def _start_health_check_server(self):
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get('/health', self._health_check)
|
||||||
|
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
|
||||||
|
try:
|
||||||
|
site = web.TCPSite(runner, port=self._health_check_port)
|
||||||
|
await site.start()
|
||||||
|
self._health_check_site = site
|
||||||
|
logging.info(f"health check server started on port {self._health_check_port}")
|
||||||
|
except OSError as e:
|
||||||
|
if e.errno == 98:
|
||||||
|
logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
|
||||||
|
else:
|
||||||
|
logging.error(f"failed to start health check server with error {str(e)}, starting anyway")
|
||||||
|
|
||||||
@tracer.start_as_current_span("Do Work Item")
|
@tracer.start_as_current_span("Do Work Item")
|
||||||
async def _do_work_item(self, request: dict) -> dict:
|
async def _do_work_item(self, request: dict) -> dict:
|
||||||
@ -74,6 +99,7 @@ class DistributedPromptWorker:
|
|||||||
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
|
||||||
|
|
||||||
await self._rpc.register(self._queue_name, self._do_work_item)
|
await self._rpc.register(self._queue_name, self._do_work_item)
|
||||||
|
await self._start_health_check_server()
|
||||||
|
|
||||||
async def __aenter__(self) -> "DistributedPromptWorker":
|
async def __aenter__(self) -> "DistributedPromptWorker":
|
||||||
await self.init()
|
await self.init()
|
||||||
@ -83,6 +109,8 @@ class DistributedPromptWorker:
|
|||||||
await self._rpc.close()
|
await self._rpc.close()
|
||||||
await self._channel.close()
|
await self._channel.close()
|
||||||
await self._connection.close()
|
await self._connection.close()
|
||||||
|
if self._health_check_site:
|
||||||
|
await self._health_check_site.stop()
|
||||||
|
|
||||||
async def close(self):
|
async def close(self):
|
||||||
await self._close()
|
await self._close()
|
||||||
|
|||||||
@ -1,13 +1,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
import operator
|
||||||
import os
|
import os
|
||||||
|
from functools import reduce
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from os.path import join
|
from os.path import join
|
||||||
from typing import List, Any, Optional, Union
|
from pathlib import Path
|
||||||
|
from typing import List, Any, Optional, Union, Sequence
|
||||||
|
|
||||||
import tqdm
|
import tqdm
|
||||||
from huggingface_hub import hf_hub_download, scan_cache_dir
|
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
|
||||||
from huggingface_hub.utils import GatedRepoError
|
from huggingface_hub.utils import GatedRepoError
|
||||||
from requests import Session
|
from requests import Session
|
||||||
from safetensors import safe_open
|
from safetensors import safe_open
|
||||||
@ -15,11 +18,13 @@ from safetensors.torch import save_file
|
|||||||
|
|
||||||
from .cli_args import args
|
from .cli_args import args
|
||||||
from .cmd import folder_paths
|
from .cmd import folder_paths
|
||||||
|
from .component_model.deprecation import _deprecate_method
|
||||||
from .interruption import InterruptProcessingException
|
from .interruption import InterruptProcessingException
|
||||||
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
|
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
|
||||||
from .utils import ProgressBar, comfy_tqdm
|
from .utils import ProgressBar, comfy_tqdm
|
||||||
|
|
||||||
_session = Session()
|
_session = Session()
|
||||||
|
_hf_fs = HfFileSystem()
|
||||||
|
|
||||||
|
|
||||||
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
|
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
|
||||||
@ -365,8 +370,121 @@ def add_known_models(folder_name: str, known_models: List[Union[CivitFile, Huggi
|
|||||||
return known_models
|
return known_models
|
||||||
|
|
||||||
|
|
||||||
|
@_deprecate_method(version="1.0.0", message="use get_huggingface_repo_list instead")
|
||||||
def huggingface_repos() -> List[str]:
|
def huggingface_repos() -> List[str]:
|
||||||
cache_info = scan_cache_dir()
|
return get_huggingface_repo_list()
|
||||||
existing_repo_ids = frozenset(cache_item.repo_id for cache_item in cache_info.repos if cache_item.repo_type == "model")
|
|
||||||
|
|
||||||
|
def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
|
||||||
|
if len(extra_cache_dirs) == 0:
|
||||||
|
extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
|
||||||
|
|
||||||
|
# all in cache directories
|
||||||
|
existing_repo_ids = frozenset(
|
||||||
|
cache_item.repo_id for cache_item in \
|
||||||
|
reduce(operator.or_,
|
||||||
|
map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
|
||||||
|
if cache_item.repo_type == "model"
|
||||||
|
)
|
||||||
|
|
||||||
|
# also check local-dir style directories
|
||||||
|
existing_local_dir_repos = set()
|
||||||
|
local_dirs = folder_paths.get_folder_paths("huggingface")
|
||||||
|
for local_dir_root in local_dirs:
|
||||||
|
# enumerate all the two-directory paths
|
||||||
|
if not os.path.isdir(local_dir_root):
|
||||||
|
continue
|
||||||
|
|
||||||
|
for user_dir in Path(local_dir_root).iterdir():
|
||||||
|
for model_dir in user_dir.iterdir():
|
||||||
|
try:
|
||||||
|
_hf_fs.resolve_path(str(user_dir / model_dir))
|
||||||
|
except Exception as exc_info:
|
||||||
|
logging.debug(f"HuggingFaceFS did not think this was a valid repo: {user_dir.name}/{model_dir.name} with error {exc_info}", exc_info)
|
||||||
|
existing_local_dir_repos.add(f"{user_dir.name}/{model_dir.name}")
|
||||||
|
|
||||||
known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
|
known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
|
||||||
return list(existing_repo_ids | known_repo_ids)
|
if args.disable_known_models:
|
||||||
|
return list(existing_repo_ids | existing_local_dir_repos)
|
||||||
|
else:
|
||||||
|
return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
|
||||||
|
|
||||||
|
|
||||||
|
def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
|
||||||
|
cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
|
||||||
|
local_dirs = folder_paths.get_folder_paths("huggingface")
|
||||||
|
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
|
||||||
|
|
||||||
|
local_dirs_cache_hit = len(local_dirs_snapshots) > 0
|
||||||
|
cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
|
||||||
|
logging.debug(f"cache {'hit' if local_dirs_cache_hit or cache_dirs_cache_hit else 'miss'} for repo_id={repo_id} because local_dirs={local_dirs_cache_hit}, cache_dirs={cache_dirs_cache_hit}")
|
||||||
|
|
||||||
|
# if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
|
||||||
|
if args.force_hf_local_dir_mode:
|
||||||
|
# todo: we still have to figure out a way to download things to the right places by default
|
||||||
|
if len(local_dirs_snapshots) > 0:
|
||||||
|
return local_dirs_snapshots[0]
|
||||||
|
elif not args.disable_known_models:
|
||||||
|
destination = os.path.join(local_dirs[0], repo_id)
|
||||||
|
logging.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
|
||||||
|
return snapshot_download(repo_id, local_dir=destination)
|
||||||
|
|
||||||
|
snapshots = local_dirs_snapshots + cache_dirs_snapshots
|
||||||
|
if len(snapshots) > 0:
|
||||||
|
return snapshots[0]
|
||||||
|
elif not args.disable_known_models:
|
||||||
|
logging.debug(f"downloading repo_id={repo_id}")
|
||||||
|
return snapshot_download(repo_id)
|
||||||
|
|
||||||
|
# this repo was not found
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
|
||||||
|
local_dirs_snapshots = []
|
||||||
|
cache_dirs_snapshots = []
|
||||||
|
# find all the pre-existing downloads for this repo_id
|
||||||
|
try:
|
||||||
|
repo_files = set(_hf_fs.ls(repo_id, detail=False))
|
||||||
|
except:
|
||||||
|
repo_files = []
|
||||||
|
|
||||||
|
if len(repo_files) > 0:
|
||||||
|
for local_dir in local_dirs:
|
||||||
|
local_path = Path(local_dir) / repo_id
|
||||||
|
local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
|
||||||
|
# fix path representation
|
||||||
|
local_files = set(f.replace("\\", "/") for f in local_files)
|
||||||
|
# remove .huggingface
|
||||||
|
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
|
||||||
|
# local_files.issubsetof(repo_files)
|
||||||
|
if local_files.issubset(repo_files):
|
||||||
|
local_dirs_snapshots.append(str(local_path))
|
||||||
|
else:
|
||||||
|
# an empty repository or unknown repository info, trust that if the directory exists, it matches
|
||||||
|
for local_dir in local_dirs:
|
||||||
|
local_path = Path(local_dir) / repo_id
|
||||||
|
if local_path.is_dir():
|
||||||
|
local_dirs_snapshots.append(str(local_path))
|
||||||
|
|
||||||
|
for cache_dir in (None, *cache_dirs):
|
||||||
|
try:
|
||||||
|
cache_dirs_snapshots.append(snapshot_download(repo_id, local_files_only=True, cache_dir=cache_dir))
|
||||||
|
except FileNotFoundError:
|
||||||
|
continue
|
||||||
|
except:
|
||||||
|
continue
|
||||||
|
return cache_dirs_snapshots, local_dirs_snapshots
|
||||||
|
|
||||||
|
|
||||||
|
def _delete_repo_from_huggingface_cache(repo_id: str, cache_dir: Optional[str] = None) -> List[str]:
|
||||||
|
results = scan_cache_dir(cache_dir)
|
||||||
|
matching = [repo for repo in results.repos if repo.repo_id == repo_id]
|
||||||
|
if len(matching) == 0:
|
||||||
|
return []
|
||||||
|
revisions: List[str] = []
|
||||||
|
for repo in matching:
|
||||||
|
for revision_info in repo.revisions:
|
||||||
|
revisions.append(revision_info.commit_hash)
|
||||||
|
results.delete_revisions(*revisions).execute()
|
||||||
|
return revisions
|
||||||
|
|||||||
@ -43,7 +43,6 @@ class HuggingFile:
|
|||||||
Attributes:
|
Attributes:
|
||||||
repo_id (str): The Huggingface repository of a known file
|
repo_id (str): The Huggingface repository of a known file
|
||||||
filename (str): The path to the known file in the repository
|
filename (str): The path to the known file in the repository
|
||||||
show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
|
|
||||||
"""
|
"""
|
||||||
repo_id: str
|
repo_id: str
|
||||||
filename: str
|
filename: str
|
||||||
|
|||||||
@ -26,7 +26,7 @@ from ..cli_args import args
|
|||||||
from ..cmd import folder_paths, latent_preview
|
from ..cmd import folder_paths, latent_preview
|
||||||
from ..execution_context import current_execution_context
|
from ..execution_context import current_execution_context
|
||||||
from ..images import open_image
|
from ..images import open_image
|
||||||
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
|
||||||
from ..nodes.common import MAX_RESOLUTION
|
from ..nodes.common import MAX_RESOLUTION
|
||||||
from .. import controlnet
|
from .. import controlnet
|
||||||
from ..open_exr import load_exr
|
from ..open_exr import load_exr
|
||||||
@ -517,7 +517,7 @@ class DiffusersLoader:
|
|||||||
if "model_index.json" in files:
|
if "model_index.json" in files:
|
||||||
paths.append(os.path.relpath(root, start=search_path))
|
paths.append(os.path.relpath(root, start=search_path))
|
||||||
|
|
||||||
paths += huggingface_repos()
|
paths += get_huggingface_repo_list()
|
||||||
paths = list(frozenset(paths))
|
paths = list(frozenset(paths))
|
||||||
return {"required": {"model_path": (paths,), }}
|
return {"required": {"model_path": (paths,), }}
|
||||||
|
|
||||||
|
|||||||
@ -9,15 +9,15 @@ from functools import reduce
|
|||||||
from typing import Any, Dict, Optional, List, Callable, Union
|
from typing import Any, Dict, Optional, List, Callable, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
|
||||||
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
|
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
|
||||||
LlavaNextForConditionalGeneration, LlavaNextProcessor, T5EncoderModel, AutoModel
|
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
|
||||||
from typing_extensions import TypedDict
|
from typing_extensions import TypedDict
|
||||||
|
|
||||||
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
|
||||||
from comfy.language.language_types import ProcessorResult
|
from comfy.language.language_types import ProcessorResult
|
||||||
from comfy.language.transformers_model_management import TransformersManagedModel
|
from comfy.language.transformers_model_management import TransformersManagedModel
|
||||||
from comfy.model_downloader import huggingface_repos
|
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
|
||||||
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
|
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
|
||||||
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
|
||||||
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
|
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
|
||||||
@ -197,7 +197,7 @@ class TransformersImageProcessorLoader(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"ckpt_name": (huggingface_repos(),),
|
"ckpt_name": (get_huggingface_repo_list(),),
|
||||||
"subfolder": ("STRING", {}),
|
"subfolder": ("STRING", {}),
|
||||||
"model": ("MODEL", {}),
|
"model": ("MODEL", {}),
|
||||||
"overwrite_tokenizer": ("BOOLEAN", {"default": False}),
|
"overwrite_tokenizer": ("BOOLEAN", {"default": False}),
|
||||||
@ -212,6 +212,7 @@ class TransformersImageProcessorLoader(CustomNode):
|
|||||||
hub_kwargs = {}
|
hub_kwargs = {}
|
||||||
if subfolder is not None and subfolder != "":
|
if subfolder is not None and subfolder != "":
|
||||||
hub_kwargs["subfolder"] = subfolder
|
hub_kwargs["subfolder"] = subfolder
|
||||||
|
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
|
||||||
processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
|
processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
|
||||||
return model.patch_processor(processor, overwrite_tokenizer),
|
return model.patch_processor(processor, overwrite_tokenizer),
|
||||||
|
|
||||||
@ -221,7 +222,7 @@ class TransformersLoader(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"ckpt_name": (huggingface_repos(),),
|
"ckpt_name": (get_huggingface_repo_list(),),
|
||||||
"subfolder": ("STRING", {})
|
"subfolder": ("STRING", {})
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@ -234,6 +235,8 @@ class TransformersLoader(CustomNode):
|
|||||||
hub_kwargs = {}
|
hub_kwargs = {}
|
||||||
if subfolder is not None and subfolder != "":
|
if subfolder is not None and subfolder != "":
|
||||||
hub_kwargs["subfolder"] = subfolder
|
hub_kwargs["subfolder"] = subfolder
|
||||||
|
|
||||||
|
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
|
||||||
with comfy_tqdm():
|
with comfy_tqdm():
|
||||||
from_pretrained_kwargs = {
|
from_pretrained_kwargs = {
|
||||||
"pretrained_model_name_or_path": ckpt_name,
|
"pretrained_model_name_or_path": ckpt_name,
|
||||||
|
|||||||
@ -112,7 +112,7 @@ class FloatRequestParameter(CustomNode):
|
|||||||
def INPUT_TYPES(cls) -> InputTypes:
|
def INPUT_TYPES(cls) -> InputTypes:
|
||||||
return {
|
return {
|
||||||
"required": {
|
"required": {
|
||||||
"value": ("FLOAT", {"default": 0})
|
"value": ("FLOAT", {"default": 0, "step": 0.00001, "round": 0.00001})
|
||||||
},
|
},
|
||||||
"optional": {
|
"optional": {
|
||||||
**_open_api_common_schema,
|
**_open_api_common_schema,
|
||||||
|
|||||||
@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
|
|||||||
|
|
||||||
import jwt
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
from aiohttp import ClientSession, ClientConnectorError
|
||||||
from testcontainers.rabbitmq import RabbitMqContainer
|
from testcontainers.rabbitmq import RabbitMqContainer
|
||||||
|
|
||||||
from comfy.client.aio_client import AsyncRemoteComfyClient
|
from comfy.client.aio_client import AsyncRemoteComfyClient
|
||||||
@ -108,3 +109,97 @@ async def test_frontend_backend_workers_validation_error_raises(frontend_backend
|
|||||||
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
|
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
|
||||||
with pytest.raises(Exception):
|
with pytest.raises(Exception):
|
||||||
await client.queue_prompt(prompt)
|
await client.queue_prompt(prompt)
|
||||||
|
|
||||||
|
|
||||||
|
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
|
||||||
|
async with ClientSession() as session:
|
||||||
|
for _ in range(max_retries):
|
||||||
|
try:
|
||||||
|
async with session.get(url, timeout=1) as response:
|
||||||
|
if response.status == 200 and await response.text() == "OK":
|
||||||
|
return True
|
||||||
|
except Exception as exc_info:
|
||||||
|
pass
|
||||||
|
await asyncio.sleep(retry_delay)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_basic_queue_worker_with_health_check():
|
||||||
|
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||||
|
params = rabbitmq.get_connection_params()
|
||||||
|
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||||
|
health_check_port = 9090
|
||||||
|
|
||||||
|
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
|
||||||
|
# Test health check
|
||||||
|
health_check_url = f"http://localhost:{health_check_port}/health"
|
||||||
|
|
||||||
|
health_check_ok = await check_health(health_check_url)
|
||||||
|
assert health_check_ok, "Health check server did not start properly"
|
||||||
|
|
||||||
|
# Test the actual worker functionality
|
||||||
|
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||||
|
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
|
||||||
|
await distributed_queue.init()
|
||||||
|
|
||||||
|
queue_item = create_test_prompt()
|
||||||
|
res = await distributed_queue.put_async(queue_item)
|
||||||
|
|
||||||
|
assert res.item_id == queue_item.prompt_id
|
||||||
|
assert len(res.outputs) == 1
|
||||||
|
assert res.status is not None
|
||||||
|
assert res.status.status_str == "success"
|
||||||
|
|
||||||
|
await distributed_queue.close()
|
||||||
|
|
||||||
|
# Test that the health check server is stopped after the worker is closed
|
||||||
|
health_check_stopped = not await check_health(health_check_url, max_retries=1)
|
||||||
|
assert health_check_stopped, "Health check server did not stop properly"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_health_check_port_conflict():
|
||||||
|
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
|
||||||
|
params = rabbitmq.get_connection_params()
|
||||||
|
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
|
||||||
|
health_check_port = 9090
|
||||||
|
|
||||||
|
# Start a simple server to occupy the health check port
|
||||||
|
from aiohttp import web
|
||||||
|
async def dummy_handler(request):
|
||||||
|
return web.Response(text="Dummy")
|
||||||
|
|
||||||
|
app = web.Application()
|
||||||
|
app.router.add_get('/', dummy_handler)
|
||||||
|
runner = web.AppRunner(app)
|
||||||
|
await runner.setup()
|
||||||
|
site = web.TCPSite(runner, '0.0.0.0', health_check_port)
|
||||||
|
await site.start()
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Now try to start the DistributedPromptWorker
|
||||||
|
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
|
||||||
|
# The health check should be disabled, but the worker should still function
|
||||||
|
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
|
||||||
|
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
|
||||||
|
await distributed_queue.init()
|
||||||
|
|
||||||
|
queue_item = create_test_prompt()
|
||||||
|
res = await distributed_queue.put_async(queue_item)
|
||||||
|
|
||||||
|
assert res.item_id == queue_item.prompt_id
|
||||||
|
assert len(res.outputs) == 1
|
||||||
|
assert res.status is not None
|
||||||
|
assert res.status.status_str == "success"
|
||||||
|
|
||||||
|
await distributed_queue.close()
|
||||||
|
|
||||||
|
# The original server should still be running
|
||||||
|
async with ClientSession() as session:
|
||||||
|
async with session.get(f"http://localhost:{health_check_port}") as response:
|
||||||
|
assert response.status == 200
|
||||||
|
assert await response.text() == "Dummy"
|
||||||
|
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
|||||||
0
tests/downloader/__init__.py
Normal file
0
tests/downloader/__init__.py
Normal file
152
tests/downloader/test_huggingface_downloads.py
Normal file
152
tests/downloader/test_huggingface_downloads.py
Normal file
@ -0,0 +1,152 @@
|
|||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from comfy.cli_args import args
|
||||||
|
from comfy.cmd import folder_paths
|
||||||
|
from comfy.cmd.folder_paths import FolderPathsTuple
|
||||||
|
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS, get_huggingface_repo_list, \
|
||||||
|
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
|
||||||
|
|
||||||
|
_gitattributes = """*.7z filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.arrow filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bin filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ftz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.gz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.h5 filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.joblib filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.model filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npy filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.npz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.onnx filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.ot filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.parquet filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pb filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pickle filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pkl filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pt filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.pth filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.rar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
||||||
|
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tar filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tflite filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.tgz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.wasm filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.xz filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zip filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*.zst filter=lfs diff=lfs merge=lfs -text
|
||||||
|
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
def test_known_repos(tmp_path_factory):
|
||||||
|
test_cache_dir = tmp_path_factory.mktemp("huggingface_cache")
|
||||||
|
test_local_dir = tmp_path_factory.mktemp("huggingface_locals")
|
||||||
|
test_repo_id = "doctorpangloss/comfyui_downloader_test"
|
||||||
|
prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
|
||||||
|
prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
|
||||||
|
prev_hub_cache = os.getenv("HF_HUB_CACHE")
|
||||||
|
_delete_repo_from_huggingface_cache(test_repo_id)
|
||||||
|
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
|
||||||
|
try:
|
||||||
|
folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
|
||||||
|
folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
|
||||||
|
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
|
||||||
|
|
||||||
|
# test downloading the repo and observing a cache hit on second access
|
||||||
|
existing_repos = get_huggingface_repo_list()
|
||||||
|
assert test_repo_id not in existing_repos
|
||||||
|
|
||||||
|
KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
|
||||||
|
existing_repos = get_huggingface_repo_list()
|
||||||
|
assert test_repo_id in existing_repos
|
||||||
|
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
|
||||||
|
|
||||||
|
# download to cache
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is not None
|
||||||
|
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == 1, "should have downloaded to cache"
|
||||||
|
assert len(locals_hits) == 0, "should not have downloaded to a local dir"
|
||||||
|
|
||||||
|
# load from cache
|
||||||
|
args.disable_known_models = True
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is not None, "should have used local path"
|
||||||
|
|
||||||
|
# test deleting from cache
|
||||||
|
_delete_repo_from_huggingface_cache(test_repo_id)
|
||||||
|
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == 0, "should have deleted from the cache"
|
||||||
|
assert len(locals_hits) == 0, "should not have downloaded to a local dir"
|
||||||
|
|
||||||
|
# test fails to download
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is None, "should not have downloaded since disable_known_models is True"
|
||||||
|
args.disable_known_models = False
|
||||||
|
|
||||||
|
# download to local dir
|
||||||
|
args.force_hf_local_dir_mode = True
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is not None
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == 0
|
||||||
|
assert len(locals_hits) == 1, "should have downloaded to local dir"
|
||||||
|
|
||||||
|
# test loads from local dir
|
||||||
|
args.disable_known_models = True
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is not None
|
||||||
|
|
||||||
|
# test deleting local dir
|
||||||
|
expected_path = os.path.join(test_local_dir, test_repo_id)
|
||||||
|
shutil.rmtree(expected_path)
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == 0
|
||||||
|
assert len(locals_hits) == 0
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is None, "should not download repo into local dir"
|
||||||
|
|
||||||
|
# recreating the test repo should be valid
|
||||||
|
os.makedirs(expected_path)
|
||||||
|
with open(os.path.join(expected_path, "test.txt"), "wt") as f:
|
||||||
|
f.write("OK")
|
||||||
|
with open(os.path.join(expected_path, ".gitattributes"), "wt") as f:
|
||||||
|
f.write(_gitattributes)
|
||||||
|
|
||||||
|
args.disable_known_models = False
|
||||||
|
# expect local hit
|
||||||
|
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
|
||||||
|
assert len(cache_hits) == 0
|
||||||
|
assert len(locals_hits) == 1
|
||||||
|
|
||||||
|
# should not download
|
||||||
|
path = get_or_download_huggingface_repo(test_repo_id)
|
||||||
|
assert path is not None
|
||||||
|
finally:
|
||||||
|
_delete_repo_from_huggingface_cache(test_repo_id)
|
||||||
|
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
|
||||||
|
KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
|
||||||
|
folder_paths.folder_names_and_paths["huggingface"] = prev_huggingface
|
||||||
|
folder_paths.folder_names_and_paths["huggingface_cache"] = prev_huggingface_cache
|
||||||
|
if prev_hub_cache is None and "HF_HUB_CACHE" in os.environ:
|
||||||
|
os.environ.pop("HF_HUB_CACHE")
|
||||||
|
elif prev_hub_cache is not None:
|
||||||
|
os.environ["HF_HUB_CACHE"] = prev_hub_cache
|
||||||
|
args.force_hf_local_dir_mode = False
|
||||||
|
args.disable_known_models = False
|
||||||
Loading…
Reference in New Issue
Block a user