From 3d672249377f9ac6cb6366b571ac66d5d38dcb48 Mon Sep 17 00:00:00 2001
From: doctorpangloss <@hiddenswitch.com>
Date: Tue, 9 Jul 2024 12:57:33 -0700
Subject: [PATCH] Improve model downloading from Hugging Face Hub

---
 comfy/cli_args.py                             |   1 +
 comfy/cli_args_types.py                       |   2 +
 comfy/cmd/folder_paths.py                     |   7 +-
 comfy/component_model/deprecation.py          |  36 +++++
 .../distributed/distributed_prompt_worker.py  |  28 ++++
 comfy/model_downloader.py                     | 128 ++++++++++++++-
 comfy/model_downloader_types.py               |   1 -
 comfy/nodes/base_nodes.py                     |   4 +-
 comfy_extras/nodes/nodes_language.py          |  13 +-
 comfy_extras/nodes/nodes_open_api.py          |   2 +-
 tests/distributed/test_distributed_queue.py   |  95 +++++++++++
 tests/downloader/__init__.py                  |   0
 .../downloader/test_huggingface_downloads.py  | 152 ++++++++++++++++++
 13 files changed, 453 insertions(+), 16 deletions(-)
 create mode 100644 comfy/component_model/deprecation.py
 create mode 100644 tests/downloader/__init__.py
 create mode 100644 tests/downloader/test_huggingface_downloads.py

diff --git a/comfy/cli_args.py b/comfy/cli_args.py
index 0b5ebd28d..903ed9408 100644
--- a/comfy/cli_args.py
+++ b/comfy/cli_args.py
@@ -158,6 +158,7 @@ def _create_parser() -> EnhancedConfigArgParser:
     parser.add_argument("--otel-service-version", type=str, default=__version__, env_var="OTEL_SERVICE_VERSION", help="The version of the service or application that is generating telemetry data.")
     parser.add_argument("--otel-exporter-otlp-endpoint", type=str, default=None, env_var="OTEL_EXPORTER_OTLP_ENDPOINT", help="A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.")
     parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
+    parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.")
 
     # now give plugins a chance to add configuration
     for entry_point in entry_points().select(group='comfyui.custom_config'):
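The new flag selects between two on-disk layouts. A minimal sketch of the distinction, assuming the default model paths registered later in this patch; destination_for is illustrative and not part of the patch:

    import os

    from comfy.cli_args import args
    from comfy.cmd import folder_paths


    def destination_for(repo_id: str) -> str:
        # local-dir mode recreates the repo's own file layout under models/huggingface
        if args.force_hf_local_dir_mode:
            return os.path.join(folder_paths.get_folder_paths("huggingface")[0], repo_id)
        # default: huggingface_hub's content-addressed cache under models/huggingface_cache
        return folder_paths.get_folder_paths("huggingface_cache")[0]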
""" def __init__(self, **kwargs): @@ -178,6 +179,7 @@ class Configuration(dict): self.disable_known_models: bool = False self.max_queue_size: int = 65536 self.force_channels_last: bool = False + self.force_hf_local_dir_mode = False # from opentracing docs self.otel_service_name: str = "comfyui" diff --git a/comfy/cmd/folder_paths.py b/comfy/cmd/folder_paths.py index 37a187024..84457015b 100644 --- a/comfy/cmd/folder_paths.py +++ b/comfy/cmd/folder_paths.py @@ -57,8 +57,6 @@ class FolderNames: if isinstance(value, tuple): paths, supported_extensions = value value = FolderPathsTuple(key, paths, supported_extensions) - if key in self.contents: - value = self.contents[key] + value self.contents[key] = value def __len__(self): @@ -67,6 +65,9 @@ class FolderNames: def __iter__(self): return iter(self.contents) + def __delitem__(self, key): + del self.contents[key] + def items(self): return self.contents.items() @@ -110,6 +111,8 @@ folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.pa folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions)) folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions)) folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""}) +folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""}) +folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""}) output_directory = os.path.join(base_path, "output") temp_directory = os.path.join(base_path, "temp") diff --git a/comfy/component_model/deprecation.py b/comfy/component_model/deprecation.py new file mode 100644 index 000000000..ace51a626 --- /dev/null +++ b/comfy/component_model/deprecation.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +import warnings +from functools import wraps +from typing import Optional + + +def _deprecate_method(*, version: str, message: Optional[str] = None): + """Decorator to issue warnings when using a deprecated method. + + Args: + version (`str`): + The version when deprecated arguments will result in error. + message (`str`, *optional*): + Warning message that is raised. If not passed, a default warning message + will be created. + """ + + def _inner_deprecate_method(f): + name = f.__name__ + if name == "__init__": + name = f.__qualname__.split(".")[0] # class name instead of method name + + @wraps(f) + def inner_f(*args, **kwargs): + warning_message = ( + f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'." 
diff --git a/comfy/component_model/deprecation.py b/comfy/component_model/deprecation.py
new file mode 100644
index 000000000..ace51a626
--- /dev/null
+++ b/comfy/component_model/deprecation.py
@@ -0,0 +1,36 @@
+from __future__ import annotations
+
+import warnings
+from functools import wraps
+from typing import Optional
+
+
+def _deprecate_method(*, version: str, message: Optional[str] = None):
+    """Decorator to issue warnings when using a deprecated method.
+
+    Args:
+        version (`str`):
+            The version in which the deprecated method will be removed.
+        message (`str`, *optional*):
+            Warning message that is raised. If not passed, a default warning message
+            will be created.
+    """
+
+    def _inner_deprecate_method(f):
+        name = f.__name__
+        if name == "__init__":
+            name = f.__qualname__.split(".")[0]  # class name instead of method name
+
+        @wraps(f)
+        def inner_f(*args, **kwargs):
+            warning_message = (
+                f"'{name}' (from '{f.__module__}') is deprecated and will be removed in version '{version}'."
+            )
+            if message is not None:
+                warning_message += " " + message
+            warnings.warn(warning_message, FutureWarning)
+            return f(*args, **kwargs)
+
+        return inner_f
+
+    return _inner_deprecate_method
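Illustrative usage, mirroring how the decorator is applied to huggingface_repos later in this patch; old_api and new_api are hypothetical names:

    from comfy.component_model.deprecation import _deprecate_method


    @_deprecate_method(version="1.0.0", message="use new_api instead")
    def old_api() -> None:
        ...


    # calling it emits a FutureWarning:
    # 'old_api' (from '__main__') is deprecated and will be removed in version '1.0.0'. use new_api instead
    old_api()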
diff --git a/comfy/distributed/distributed_prompt_worker.py b/comfy/distributed/distributed_prompt_worker.py
index fccac5f01..747358ec6 100644
--- a/comfy/distributed/distributed_prompt_worker.py
+++ b/comfy/distributed/distributed_prompt_worker.py
@@ -7,6 +7,8 @@ from typing import Optional
 
 from aio_pika import connect_robust
 from aio_pika.patterns import JsonRPC
+import errno
+from aiohttp import web
 from aiormq import AMQPConnectionError
 
 from .distributed_progress import DistributedExecutorToClientProgress
@@ -24,6 +26,7 @@ class DistributedPromptWorker:
     def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
                  connection_uri: str = "amqp://localhost:5672/",
                  queue_name: str = "comfyui",
+                 health_check_port: int = 9090,
                  loop: Optional[AbstractEventLoop] = None):
         self._rpc = None
         self._channel = None
@@ -32,6 +35,29 @@ class DistributedPromptWorker:
         self._connection_uri = connection_uri
         self._loop = loop or asyncio.get_event_loop()
         self._embedded_comfy_client = embedded_comfy_client
+        self._health_check_port = health_check_port
+        self._health_check_site = None
+
+    async def _health_check(self, request):
+        return web.Response(text="OK", content_type="text/plain")
+
+    async def _start_health_check_server(self):
+        app = web.Application()
+        app.router.add_get('/health', self._health_check)
+
+        runner = web.AppRunner(app)
+        await runner.setup()
+
+        try:
+            site = web.TCPSite(runner, port=self._health_check_port)
+            await site.start()
+            self._health_check_site = site
+            logging.info(f"health check server started on port {self._health_check_port}")
+        except OSError as e:
+            if e.errno == errno.EADDRINUSE:
+                logging.warning(f"port {self._health_check_port} is already in use; the health check is disabled, but the worker will start anyway")
+            else:
+                logging.error(f"failed to start the health check server ({e}); the worker will start anyway")
 
     @tracer.start_as_current_span("Do Work Item")
     async def _do_work_item(self, request: dict) -> dict:
@@ -74,6 +100,7 @@
         await self._exit_stack.enter_async_context(self._embedded_comfy_client)
 
         await self._rpc.register(self._queue_name, self._do_work_item)
+        await self._start_health_check_server()
 
     async def __aenter__(self) -> "DistributedPromptWorker":
         await self.init()
@@ -83,6 +110,8 @@
         await self._rpc.close()
         await self._channel.close()
         await self._connection.close()
+        if self._health_check_site:
+            await self._health_check_site.stop()
 
     async def close(self):
         await self._close()
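A short sketch of probing the new liveness endpoint from the outside; it mirrors the check_health helper added to the tests further below:

    import asyncio

    from aiohttp import ClientSession


    async def is_alive(port: int = 9090) -> bool:
        # the worker answers GET /health with a plain-text "OK" once it is up
        try:
            async with ClientSession() as session:
                async with session.get(f"http://localhost:{port}/health", timeout=1) as response:
                    return response.status == 200 and await response.text() == "OK"
        except Exception:
            return False


    if __name__ == "__main__":
        print(asyncio.run(is_alive()))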
diff --git a/comfy/model_downloader.py b/comfy/model_downloader.py
index 68f669642..42a781c00 100644
--- a/comfy/model_downloader.py
+++ b/comfy/model_downloader.py
@@ -1,13 +1,16 @@
 from __future__ import annotations
 
 import logging
+import operator
 import os
+from functools import reduce
 from itertools import chain
 from os.path import join
-from typing import List, Any, Optional, Union
+from pathlib import Path
+from typing import List, Any, Optional, Union, Sequence
 
 import tqdm
-from huggingface_hub import hf_hub_download, scan_cache_dir
+from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
 from huggingface_hub.utils import GatedRepoError
 from requests import Session
 from safetensors import safe_open
@@ -15,11 +18,13 @@ from safetensors.torch import save_file
 
 from .cli_args import args
 from .cmd import folder_paths
+from .component_model.deprecation import _deprecate_method
 from .interruption import InterruptProcessingException
 from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
 from .utils import ProgressBar, comfy_tqdm
 
 _session = Session()
+_hf_fs = HfFileSystem()
 
 
 def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
@@ -365,8 +370,128 @@ def add_known_models(folder_name: str, known_models: List[Union[CivitFile, Huggi
     return known_models
 
 
+@_deprecate_method(version="1.0.0", message="use get_huggingface_repo_list instead")
 def huggingface_repos() -> List[str]:
-    cache_info = scan_cache_dir()
-    existing_repo_ids = frozenset(cache_item.repo_id for cache_item in cache_info.repos if cache_item.repo_type == "model")
+    return get_huggingface_repo_list()
+
+
+def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
+    if len(extra_cache_dirs) == 0:
+        extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
+
+    # all repos in the default cache plus the extra cache directories
+    existing_repo_ids = frozenset(
+        cache_item.repo_id for cache_item in
+        reduce(operator.or_,
+               map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
+        if cache_item.repo_type == "model"
+    )
+
+    # also check local-dir style directories
+    existing_local_dir_repos = set()
+    local_dirs = folder_paths.get_folder_paths("huggingface")
+    for local_dir_root in local_dirs:
+        # enumerate all the user/model two-level paths
+        if not os.path.isdir(local_dir_root):
+            continue
+
+        for user_dir in Path(local_dir_root).iterdir():
+            if not user_dir.is_dir():
+                continue
+            for model_dir in user_dir.iterdir():
+                if not model_dir.is_dir():
+                    continue
+                repo_id = f"{user_dir.name}/{model_dir.name}"
+                try:
+                    _hf_fs.resolve_path(repo_id)
+                except Exception as exc:
+                    logging.debug(f"HfFileSystem did not think {repo_id} was a valid repo", exc_info=exc)
+                    continue
+                existing_local_dir_repos.add(repo_id)
 
     known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
-    return list(existing_repo_ids | known_repo_ids)
+    if args.disable_known_models:
+        return list(existing_repo_ids | existing_local_dir_repos)
+    else:
+        return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
+
+
+def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
+    cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
+    local_dirs = folder_paths.get_folder_paths("huggingface")
+    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
+
+    local_dirs_cache_hit = len(local_dirs_snapshots) > 0
+    cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
+    logging.debug(f"cache {'hit' if local_dirs_cache_hit or cache_dirs_cache_hit else 'miss'} for repo_id={repo_id} because local_dirs={local_dirs_cache_hit}, cache_dirs={cache_dirs_cache_hit}")
+
+    # in forced local directory mode, prefer the local dir snapshots, and otherwise download into a local dir
+    if args.force_hf_local_dir_mode:
+        # todo: we still have to figure out a way to download things to the right places by default
+        if len(local_dirs_snapshots) > 0:
+            return local_dirs_snapshots[0]
+        elif not args.disable_known_models:
+            destination = os.path.join(local_dirs[0], repo_id)
+            logging.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
+            return snapshot_download(repo_id, local_dir=destination)
+
+    snapshots = local_dirs_snapshots + cache_dirs_snapshots
+    if len(snapshots) > 0:
+        return snapshots[0]
+    elif not args.disable_known_models:
+        logging.debug(f"downloading repo_id={repo_id}")
+        return snapshot_download(repo_id)
+
+    # this repo was not found
+    return None
+
+
+def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id: str):
+    local_dirs_snapshots = []
+    cache_dirs_snapshots = []
+    # find all the pre-existing downloads for this repo_id
+    try:
+        repo_files = set(_hf_fs.ls(repo_id, detail=False))
+    except Exception:
+        repo_files = set()
+
+    if len(repo_files) > 0:
+        for local_dir in local_dirs:
+            local_path = Path(local_dir) / repo_id
+            if not local_path.is_dir():
+                continue
+            local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
+            # normalize the path separators on Windows
+            local_files = set(f.replace("\\", "/") for f in local_files)
+            # ignore the .huggingface metadata directory
+            local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
+            # a local dir counts as a snapshot when everything in it also exists in the remote repo
+            if len(local_files) > 0 and local_files.issubset(repo_files):
+                local_dirs_snapshots.append(str(local_path))
+    else:
+        # an empty repository or unknown repository info; trust that if the directory exists, it matches
+        for local_dir in local_dirs:
+            local_path = Path(local_dir) / repo_id
+            if local_path.is_dir():
+                local_dirs_snapshots.append(str(local_path))
+
+    for cache_dir in (None, *cache_dirs):
+        try:
+            cache_dirs_snapshots.append(snapshot_download(repo_id, local_files_only=True, cache_dir=cache_dir))
+        except Exception:
+            # not downloaded into this cache directory
+            continue
+    return cache_dirs_snapshots, local_dirs_snapshots
+
+
+def _delete_repo_from_huggingface_cache(repo_id: str, cache_dir: Optional[str] = None) -> List[str]:
+    results = scan_cache_dir(cache_dir)
+    matching = [repo for repo in results.repos if repo.repo_id == repo_id]
+    if len(matching) == 0:
+        return []
+    revisions: List[str] = []
+    for repo in matching:
+        for revision_info in repo.revisions:
+            revisions.append(revision_info.commit_hash)
+    results.delete_revisions(*revisions).execute()
+    return revisions
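The resolution order in a sketch, using the test repo that appears later in this patch: local-dir snapshots win, then cache snapshots, then a fresh download; None means the repo is not on disk and downloads are disabled:

    from comfy.model_downloader import get_or_download_huggingface_repo

    path = get_or_download_huggingface_repo("doctorpangloss/comfyui_downloader_test")
    if path is None:
        # only happens when --disable-known-models is set and nothing is on disk
        raise RuntimeError("repo unavailable")
    print(path)  # a local snapshot directory usable with from_pretrained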
diff --git a/comfy/model_downloader_types.py b/comfy/model_downloader_types.py
index 4ff6df0a7..8469d1efc 100644
--- a/comfy/model_downloader_types.py
+++ b/comfy/model_downloader_types.py
@@ -43,7 +43,6 @@ class HuggingFile:
     Attributes:
         repo_id (str): The Huggingface repository of a known file
         filename (str): The path to the known file in the repository
-        show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
     """
     repo_id: str
     filename: str
diff --git a/comfy/nodes/base_nodes.py b/comfy/nodes/base_nodes.py
index 00dbf1c39..7a9e33a55 100644
--- a/comfy/nodes/base_nodes.py
+++ b/comfy/nodes/base_nodes.py
@@ -26,7 +26,7 @@ from ..cli_args import args
 from ..cmd import folder_paths, latent_preview
 from ..execution_context import current_execution_context
 from ..images import open_image
-from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
+from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
 from ..nodes.common import MAX_RESOLUTION
 from .. import controlnet
 from ..open_exr import load_exr
@@ -517,7 +517,7 @@ class DiffusersLoader:
                 if "model_index.json" in files:
                     paths.append(os.path.relpath(root, start=search_path))
 
-        paths += huggingface_repos()
+        paths += get_huggingface_repo_list()
         paths = list(frozenset(paths))
         return {"required": {"model_path": (paths,), }}
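A sketch of what the dropdown sources now see; note that when extra cache directories are passed explicitly, they are scanned instead of the configured models/huggingface_cache, while huggingface_hub's default cache is always scanned:

    from comfy.model_downloader import get_huggingface_repo_list

    # cached repos, local-dir repos and (unless disabled) the known-model catalog
    print(get_huggingface_repo_list())

    # hypothetical extra cache location, replacing the configured cache dirs
    print(get_huggingface_repo_list("/mnt/shared/hf_cache"))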
diff --git a/comfy_extras/nodes/nodes_language.py b/comfy_extras/nodes/nodes_language.py
index 8c7c86c38..b67236b1d 100644
--- a/comfy_extras/nodes/nodes_language.py
+++ b/comfy_extras/nodes/nodes_language.py
@@ -9,15 +9,15 @@ from functools import reduce
 from typing import Any, Dict, Optional, List, Callable, Union
 
 import torch
-from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
+from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
     PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
-    LlavaNextForConditionalGeneration, LlavaNextProcessor, T5EncoderModel, AutoModel
+    LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
 from typing_extensions import TypedDict
 
 from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
 from comfy.language.language_types import ProcessorResult
 from comfy.language.transformers_model_management import TransformersManagedModel
-from comfy.model_downloader import huggingface_repos
+from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
 from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
 from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
 from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
@@ -197,7 +197,7 @@ class TransformersImageProcessorLoader(CustomNode):
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "ckpt_name": (huggingface_repos(),),
+                "ckpt_name": (get_huggingface_repo_list(),),
                 "subfolder": ("STRING", {}),
                 "model": ("MODEL", {}),
                 "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
@@ -212,6 +212,7 @@
         hub_kwargs = {}
         if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
+        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
         processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
         return model.patch_processor(processor, overwrite_tokenizer),
 
@@ -221,7 +222,7 @@ class TransformersLoader(CustomNode):
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "ckpt_name": (huggingface_repos(),),
+                "ckpt_name": (get_huggingface_repo_list(),),
                 "subfolder": ("STRING", {})
             },
         }
@@ -234,6 +235,8 @@
         hub_kwargs = {}
         if subfolder is not None and subfolder != "":
             hub_kwargs["subfolder"] = subfolder
+
+        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
         with comfy_tqdm():
             from_pretrained_kwargs = {
                 "pretrained_model_name_or_path": ckpt_name,
diff --git a/comfy_extras/nodes/nodes_open_api.py b/comfy_extras/nodes/nodes_open_api.py
index efa2ead40..884b3d1a9 100644
--- a/comfy_extras/nodes/nodes_open_api.py
+++ b/comfy_extras/nodes/nodes_open_api.py
@@ -112,7 +112,7 @@ class FloatRequestParameter(CustomNode):
     def INPUT_TYPES(cls) -> InputTypes:
         return {
             "required": {
-                "value": ("FLOAT", {"default": 0})
+                "value": ("FLOAT", {"default": 0, "step": 0.00001, "round": 0.00001})
             },
             "optional": {
                 **_open_api_common_schema,
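The loader changes above follow one pattern: resolve the repo to a local snapshot first, then hand transformers a filesystem path. A hedged sketch of that flow, assuming the repo contains tokenizer files (the repo id is illustrative):

    from comfy.model_downloader import get_or_download_huggingface_repo
    from transformers import AutoTokenizer

    ckpt_name = get_or_download_huggingface_repo("doctorpangloss/comfyui_downloader_test")
    # with a resolved local path, from_pretrained should not need to touch the network
    tokenizer = AutoTokenizer.from_pretrained(ckpt_name)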
diff --git a/tests/distributed/test_distributed_queue.py b/tests/distributed/test_distributed_queue.py
index 90111a1a3..8df9d6d36 100644
--- a/tests/distributed/test_distributed_queue.py
+++ b/tests/distributed/test_distributed_queue.py
@@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
 
 import jwt
 import pytest
+from aiohttp import ClientSession
 from testcontainers.rabbitmq import RabbitMqContainer
 
 from comfy.client.aio_client import AsyncRemoteComfyClient
@@ -108,3 +109,97 @@ async def test_frontend_backend_workers_validation_error_raises(frontend_backend
     prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
     with pytest.raises(Exception):
         await client.queue_prompt(prompt)
+
+
+async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
+    async with ClientSession() as session:
+        for _ in range(max_retries):
+            try:
+                async with session.get(url, timeout=1) as response:
+                    if response.status == 200 and await response.text() == "OK":
+                        return True
+            except Exception:
+                pass
+            await asyncio.sleep(retry_delay)
+    return False
+
+
+@pytest.mark.asyncio
+async def test_basic_queue_worker_with_health_check():
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+        health_check_port = 9090
+
+        async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
+            # Test health check
+            health_check_url = f"http://localhost:{health_check_port}/health"
+
+            health_check_ok = await check_health(health_check_url)
+            assert health_check_ok, "Health check server did not start properly"
+
+            # Test the actual worker functionality
+            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
+            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
+            await distributed_queue.init()
+
+            queue_item = create_test_prompt()
+            res = await distributed_queue.put_async(queue_item)
+
+            assert res.item_id == queue_item.prompt_id
+            assert len(res.outputs) == 1
+            assert res.status is not None
+            assert res.status.status_str == "success"
+
+            await distributed_queue.close()
+
+        # Test that the health check server is stopped after the worker is closed
+        health_check_stopped = not await check_health(health_check_url, max_retries=1)
+        assert health_check_stopped, "Health check server did not stop properly"
+
+
+@pytest.mark.asyncio
+async def test_health_check_port_conflict():
+    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
+        params = rabbitmq.get_connection_params()
+        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
+        health_check_port = 9090
+
+        # Start a simple server to occupy the health check port
+        from aiohttp import web
+
+        async def dummy_handler(request):
+            return web.Response(text="Dummy")
+
+        app = web.Application()
+        app.router.add_get('/', dummy_handler)
+        runner = web.AppRunner(app)
+        await runner.setup()
+        site = web.TCPSite(runner, '0.0.0.0', health_check_port)
+        await site.start()
+
+        try:
+            # Now try to start the DistributedPromptWorker
+            async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
+                # The health check should be disabled, but the worker should still function
+                from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
+                distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
+                await distributed_queue.init()
+
+                queue_item = create_test_prompt()
+                res = await distributed_queue.put_async(queue_item)
+
+                assert res.item_id == queue_item.prompt_id
+                assert len(res.outputs) == 1
+                assert res.status is not None
+                assert res.status.status_str == "success"
+
+                await distributed_queue.close()
+
+            # The original server should still be running
+            async with ClientSession() as session:
+                async with session.get(f"http://localhost:{health_check_port}") as response:
+                    assert response.status == 200
+                    assert await response.text() == "Dummy"
+
+        finally:
+            await runner.cleanup()
diff --git a/tests/downloader/__init__.py b/tests/downloader/__init__.py
new file mode 100644
index 000000000..e69de29bb
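Before the new downloader tests, a sketch of the two on-disk layouts they exercise (paths illustrative):

    # "local_dir" mode: models/huggingface/<user>/<repo> keeps the repo's own file names
    #
    #   models/huggingface/doctorpangloss/comfyui_downloader_test/
    #       .gitattributes
    #       test.txt
    #
    # "cache_dir" mode: models/huggingface_cache uses huggingface_hub's cache layout
    #
    #   models/huggingface_cache/models--doctorpangloss--comfyui_downloader_test/
    #       snapshots/<commit hash>/...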
diff --git a/tests/downloader/test_huggingface_downloads.py b/tests/downloader/test_huggingface_downloads.py
new file mode 100644
index 000000000..95a97eeda
--- /dev/null
+++ b/tests/downloader/test_huggingface_downloads.py
@@ -0,0 +1,151 @@
+import os
+import shutil
+
+import pytest
+
+from comfy.cli_args import args
+from comfy.cmd import folder_paths
+from comfy.cmd.folder_paths import FolderPathsTuple
+from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS, get_huggingface_repo_list, \
+    get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
+
+_gitattributes = """*.7z filter=lfs diff=lfs merge=lfs -text
+*.arrow filter=lfs diff=lfs merge=lfs -text
+*.bin filter=lfs diff=lfs merge=lfs -text
+*.bz2 filter=lfs diff=lfs merge=lfs -text
+*.ckpt filter=lfs diff=lfs merge=lfs -text
+*.ftz filter=lfs diff=lfs merge=lfs -text
+*.gz filter=lfs diff=lfs merge=lfs -text
+*.h5 filter=lfs diff=lfs merge=lfs -text
+*.joblib filter=lfs diff=lfs merge=lfs -text
+*.lfs.* filter=lfs diff=lfs merge=lfs -text
+*.mlmodel filter=lfs diff=lfs merge=lfs -text
+*.model filter=lfs diff=lfs merge=lfs -text
+*.msgpack filter=lfs diff=lfs merge=lfs -text
+*.npy filter=lfs diff=lfs merge=lfs -text
+*.npz filter=lfs diff=lfs merge=lfs -text
+*.onnx filter=lfs diff=lfs merge=lfs -text
+*.ot filter=lfs diff=lfs merge=lfs -text
+*.parquet filter=lfs diff=lfs merge=lfs -text
+*.pb filter=lfs diff=lfs merge=lfs -text
+*.pickle filter=lfs diff=lfs merge=lfs -text
+*.pkl filter=lfs diff=lfs merge=lfs -text
+*.pt filter=lfs diff=lfs merge=lfs -text
+*.pth filter=lfs diff=lfs merge=lfs -text
+*.rar filter=lfs diff=lfs merge=lfs -text
+*.safetensors filter=lfs diff=lfs merge=lfs -text
+saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+*.tar.* filter=lfs diff=lfs merge=lfs -text
+*.tar filter=lfs diff=lfs merge=lfs -text
+*.tflite filter=lfs diff=lfs merge=lfs -text
+*.tgz filter=lfs diff=lfs merge=lfs -text
+*.wasm filter=lfs diff=lfs merge=lfs -text
+*.xz filter=lfs diff=lfs merge=lfs -text
+*.zip filter=lfs diff=lfs merge=lfs -text
+*.zst filter=lfs diff=lfs merge=lfs -text
+*tfevents* filter=lfs diff=lfs merge=lfs -text
+"""
+
+
+def test_known_repos(tmp_path_factory):
+    test_cache_dir = tmp_path_factory.mktemp("huggingface_cache")
+    test_local_dir = tmp_path_factory.mktemp("huggingface_locals")
+    test_repo_id = "doctorpangloss/comfyui_downloader_test"
+    prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
+    prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
+    prev_hub_cache = os.getenv("HF_HUB_CACHE")
+    _delete_repo_from_huggingface_cache(test_repo_id)
+    _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
+    try:
+        folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
+        folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
+
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
+
+        # test downloading the repo and observing a cache hit on second access
+        existing_repos = get_huggingface_repo_list()
+        assert test_repo_id not in existing_repos
+
+        KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
+        existing_repos = get_huggingface_repo_list()
+        assert test_repo_id in existing_repos
+
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
+
+        # download to cache
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is not None
+
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == 1, "should have downloaded to cache"
+        assert len(locals_hits) == 0, "should not have downloaded to a local dir"
+
+        # load from cache
+        args.disable_known_models = True
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is not None, "should have used local path"
+
+        # test deleting from cache
+        _delete_repo_from_huggingface_cache(test_repo_id)
+        _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == 0, "should have deleted from the cache"
+        assert len(locals_hits) == 0, "should not have downloaded to a local dir"
+
+        # test that it fails to download
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is None, "should not have downloaded since disable_known_models is True"
+        args.disable_known_models = False
+
+        # download to local dir
+        args.force_hf_local_dir_mode = True
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is not None
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == 0
+        assert len(locals_hits) == 1, "should have downloaded to local dir"
+
+        # test loads from local dir
+        args.disable_known_models = True
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is not None
+
+        # test deleting local dir
+        expected_path = os.path.join(test_local_dir, test_repo_id)
+        shutil.rmtree(expected_path)
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == 0
+        assert len(locals_hits) == 0
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is None, "should not download repo into local dir"
+
+        # recreating the test repo should be valid
+        os.makedirs(expected_path)
+        with open(os.path.join(expected_path, "test.txt"), "wt") as f:
+            f.write("OK")
+        with open(os.path.join(expected_path, ".gitattributes"), "wt") as f:
+            f.write(_gitattributes)
+
+        args.disable_known_models = False
+        # expect local hit
+        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
+        assert len(cache_hits) == 0
+        assert len(locals_hits) == 1
+
+        # should not download
+        path = get_or_download_huggingface_repo(test_repo_id)
+        assert path is not None
+    finally:
+        _delete_repo_from_huggingface_cache(test_repo_id)
+        _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
+        KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
+        folder_paths.folder_names_and_paths["huggingface"] = prev_huggingface
+        folder_paths.folder_names_and_paths["huggingface_cache"] = prev_huggingface_cache
+        if prev_hub_cache is None and "HF_HUB_CACHE" in os.environ:
+            os.environ.pop("HF_HUB_CACHE")
+        elif prev_hub_cache is not None:
+            os.environ["HF_HUB_CACHE"] = prev_hub_cache
+        args.force_hf_local_dir_mode = False
+        args.disable_known_models = False
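Finally, a sketch of the teardown idiom the test's finally block uses; the helper returns the commit hashes it evicted:

    from comfy.model_downloader import _delete_repo_from_huggingface_cache

    # purge every cached revision of the test repo from the default hub cache
    revisions = _delete_repo_from_huggingface_cache("doctorpangloss/comfyui_downloader_test")
    print(f"deleted revisions: {revisions}")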