mirror of https://github.com/comfyanonymous/ComfyUI.git
Improve model downloading from Hugging Face Hub
This commit is contained in:
parent da21da1d8c
commit 3d67224937
@@ -158,6 +158,7 @@ def _create_parser() -> EnhancedConfigArgParser:
    parser.add_argument("--otel-service-version", type=str, default=__version__, env_var="OTEL_SERVICE_VERSION", help="The version of the service or application that is generating telemetry data.")
    parser.add_argument("--otel-exporter-otlp-endpoint", type=str, default=None, env_var="OTEL_EXPORTER_OTLP_ENDPOINT", help="A base endpoint URL for any signal type, with an optionally specified port number. Helpful when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.")
    parser.add_argument("--force-channels-last", action="store_true", help="Force channels-last memory format when running inference with the models.")
    parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.")

    # now give plugins a chance to add configuration
    for entry_point in entry_points().select(group='comfyui.custom_config'):
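For context, the two layouts this new flag switches between map directly onto huggingface_hub's snapshot_download arguments; a minimal sketch, with an illustrative repo id and paths:

from huggingface_hub import snapshot_download

# default cache_dir layout: content-addressed blobs plus snapshot symlinks under
# models/huggingface_cache/models--<org>--<repo>/snapshots/<revision>/
snapshot_download("org/repo", cache_dir="models/huggingface_cache")  # hypothetical repo id

# --force-hf-local-dir-mode layout: plain files arranged like the repo itself,
# recreating the traditional models/huggingface/<org>/<repo> structure
snapshot_download("org/repo", local_dir="models/huggingface/org/repo")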
@@ -107,6 +107,7 @@ class Configuration(dict):
        otel_service_version (str): The version of the service or application that is generating telemetry data. Default: "0.0.1".
        otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally specified port number. Helpful when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
        force_channels_last (bool): Force channels-last memory format when running inference with the models.
        force_hf_local_dir_mode (bool): Download repos from huggingface.co to the models/huggingface directory with the "local_dir" argument instead of models/huggingface_cache with the "cache_dir" argument, recreating the traditional file structure.
    """

    def __init__(self, **kwargs):
@@ -178,6 +179,7 @@ class Configuration(dict):
        self.disable_known_models: bool = False
        self.max_queue_size: int = 65536
        self.force_channels_last: bool = False
        self.force_hf_local_dir_mode: bool = False

        # from the OpenTelemetry docs
        self.otel_service_name: str = "comfyui"
@@ -57,8 +57,6 @@ class FolderNames:
        if isinstance(value, tuple):
            paths, supported_extensions = value
            value = FolderPathsTuple(key, paths, supported_extensions)
        if key in self.contents:
            value = self.contents[key] + value
        self.contents[key] = value

    def __len__(self):
@@ -67,6 +65,9 @@ class FolderNames:
    def __iter__(self):
        return iter(self.contents)

    def __delitem__(self, key):
        del self.contents[key]

    def items(self):
        return self.contents.items()
@@ -110,6 +111,8 @@ folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.pa
folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""})
folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""})

output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
36
comfy/component_model/deprecation.py
Normal file
@@ -0,0 +1,36 @@
from __future__ import annotations

import warnings
from functools import wraps
from typing import Optional


def _deprecate_method(*, version: str, message: Optional[str] = None):
    """Decorator to issue warnings when using a deprecated method.

    Args:
        version (`str`):
            The version when deprecated arguments will result in an error.
        message (`str`, *optional*):
            Warning message that is raised. If not passed, a default warning message
            will be created.
    """

    def _inner_deprecate_method(f):
        name = f.__name__
        if name == "__init__":
            name = f.__qualname__.split(".")[0]  # class name instead of method name

        @wraps(f)
        def inner_f(*args, **kwargs):
            warning_message = (
                f"'{name}' (from '{f.__module__}') is deprecated and will be removed from version '{version}'."
            )
            if message is not None:
                warning_message += " " + message
            warnings.warn(warning_message, FutureWarning)
            return f(*args, **kwargs)

        return inner_f

    return _inner_deprecate_method
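A usage sketch for the decorator (the decorated function here is hypothetical); this mirrors how the diff applies it to huggingface_repos below:

from comfy.component_model.deprecation import _deprecate_method

@_deprecate_method(version="1.0.0", message="use new_list_fn instead")
def old_list_fn():  # hypothetical function for illustration
    return []

old_list_fn()
# FutureWarning: 'old_list_fn' (from '__main__') is deprecated and will be
# removed from version '1.0.0'. use new_list_fn instead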
@@ -7,6 +7,7 @@ from typing import Optional

from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiohttp import web
from aiormq import AMQPConnectionError

from .distributed_progress import DistributedExecutorToClientProgress
@@ -24,6 +25,7 @@ class DistributedPromptWorker:
    def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
                 connection_uri: str = "amqp://localhost:5672/",
                 queue_name: str = "comfyui",
                 health_check_port: int = 9090,
                 loop: Optional[AbstractEventLoop] = None):
        self._rpc = None
        self._channel = None
@@ -32,6 +34,29 @@ class DistributedPromptWorker:
        self._connection_uri = connection_uri
        self._loop = loop or asyncio.get_event_loop()
        self._embedded_comfy_client = embedded_comfy_client
        self._health_check_port = health_check_port
        self._health_check_site = None

    async def _health_check(self, request):
        return web.Response(text="OK", content_type="text/plain")

    async def _start_health_check_server(self):
        app = web.Application()
        app.router.add_get('/health', self._health_check)

        runner = web.AppRunner(app)
        await runner.setup()

        try:
            site = web.TCPSite(runner, port=self._health_check_port)
            await site.start()
            self._health_check_site = site
            logging.info(f"health check server started on port {self._health_check_port}")
        except OSError as e:
            if e.errno == 98:  # EADDRINUSE on Linux
                logging.warning(f"port {self._health_check_port} is already in use, health check disabled but starting anyway")
            else:
                logging.error(f"failed to start health check server with error {str(e)}, starting anyway")

    @tracer.start_as_current_span("Do Work Item")
    async def _do_work_item(self, request: dict) -> dict:
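The endpoint can be probed with any HTTP client; a minimal sketch, assuming a worker listening on the default port:

import asyncio
from aiohttp import ClientSession

async def probe(port: int = 9090) -> bool:
    async with ClientSession() as session:
        async with session.get(f"http://localhost:{port}/health") as response:
            return response.status == 200 and await response.text() == "OK"

print(asyncio.run(probe()))  # True once _start_health_check_server has run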
@@ -74,6 +99,7 @@ class DistributedPromptWorker:
        await self._exit_stack.enter_async_context(self._embedded_comfy_client)

        await self._rpc.register(self._queue_name, self._do_work_item)
        await self._start_health_check_server()

    async def __aenter__(self) -> "DistributedPromptWorker":
        await self.init()
@@ -83,6 +109,8 @@ class DistributedPromptWorker:
        await self._rpc.close()
        await self._channel.close()
        await self._connection.close()
        if self._health_check_site:
            await self._health_check_site.stop()

    async def close(self):
        await self._close()
@@ -1,13 +1,16 @@
from __future__ import annotations

import logging
import operator
import os
from functools import reduce
from itertools import chain
from os.path import join
from typing import List, Any, Optional, Union
from pathlib import Path
from typing import List, Any, Optional, Union, Sequence

import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
from huggingface_hub.utils import GatedRepoError
from requests import Session
from safetensors import safe_open
@@ -15,11 +18,13 @@ from safetensors.torch import save_file

from .cli_args import args
from .cmd import folder_paths
from .component_model.deprecation import _deprecate_method
from .interruption import InterruptProcessingException
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
from .utils import ProgressBar, comfy_tqdm

_session = Session()
_hf_fs = HfFileSystem()


def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
@@ -365,8 +370,121 @@ def add_known_models(folder_name: str, known_models: List[Union[CivitFile, Huggi
    return known_models


@_deprecate_method(version="1.0.0", message="use get_huggingface_repo_list instead")
def huggingface_repos() -> List[str]:
    return get_huggingface_repo_list()

def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
    if len(extra_cache_dirs) == 0:
        extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")

    # all repos in the default cache plus the extra cache directories
    existing_repo_ids = frozenset(
        cache_item.repo_id for cache_item in
        reduce(operator.or_,
               map(lambda cache_info: cache_info.repos, [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]))
        if cache_item.repo_type == "model"
    )

    # also check local-dir style directories
    existing_local_dir_repos = set()
    local_dirs = folder_paths.get_folder_paths("huggingface")
    for local_dir_root in local_dirs:
        if not os.path.isdir(local_dir_root):
            continue

        # enumerate all the two-level user/model paths
        for user_dir in Path(local_dir_root).iterdir():
            for model_dir in user_dir.iterdir():
                try:
                    _hf_fs.resolve_path(str(user_dir / model_dir))
                except Exception as exc_info:
                    logging.debug(f"HfFileSystem did not think this was a valid repo: {user_dir.name}/{model_dir.name}", exc_info=exc_info)
                existing_local_dir_repos.add(f"{user_dir.name}/{model_dir.name}")

    known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
    if args.disable_known_models:
        return list(existing_repo_ids | existing_local_dir_repos)
    else:
        return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)

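The aggregation above means the returned list is the union of three sources; a usage sketch (the output is illustrative):

from comfy.model_downloader import get_huggingface_repo_list

# hub-cache repos + local-dir repos + (unless --disable-known-models)
# the KNOWN_HUGGINGFACE_MODEL_REPOS registry
repos = get_huggingface_repo_list()
print(repos)  # e.g. ['org/repo', ...]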
def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
    cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
    local_dirs = folder_paths.get_folder_paths("huggingface")
    cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)

    local_dirs_cache_hit = len(local_dirs_snapshots) > 0
    cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
    logging.debug(f"cache {'hit' if local_dirs_cache_hit or cache_dirs_cache_hit else 'miss'} for repo_id={repo_id} because local_dirs={local_dirs_cache_hit}, cache_dirs={cache_dirs_cache_hit}")

    # if we're in forced local-directory mode, only use the local-dir snapshots; otherwise, download
    if args.force_hf_local_dir_mode:
        # todo: we still have to figure out a way to download things to the right places by default
        if len(local_dirs_snapshots) > 0:
            return local_dirs_snapshots[0]
        elif not args.disable_known_models:
            destination = os.path.join(local_dirs[0], repo_id)
            logging.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
            return snapshot_download(repo_id, local_dir=destination)

    snapshots = local_dirs_snapshots + cache_dirs_snapshots
    if len(snapshots) > 0:
        return snapshots[0]
    elif not args.disable_known_models:
        logging.debug(f"downloading repo_id={repo_id}")
        return snapshot_download(repo_id)

    # this repo was not found
    return None

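A sketch of how the loader nodes later in this diff consume this function; the repo id is illustrative:

from comfy.model_downloader import get_or_download_huggingface_repo

path = get_or_download_huggingface_repo("org/repo")  # hypothetical repo id
if path is not None:
    # path is a local snapshot or local-dir checkout suitable for from_pretrained
    print(path)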
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
    local_dirs_snapshots = []
    cache_dirs_snapshots = []
    # find all the pre-existing downloads for this repo_id
    try:
        repo_files = set(_hf_fs.ls(repo_id, detail=False))
    except Exception:
        repo_files = set()

    if len(repo_files) > 0:
        for local_dir in local_dirs:
            local_path = Path(local_dir) / repo_id
            if not local_path.is_dir():
                continue
            local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
            # fix path representation on Windows
            local_files = set(f.replace("\\", "/") for f in local_files)
            # remove the .huggingface metadata directory
            local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
            # the checkout counts as a snapshot when everything in it also exists in the repo
            if local_files.issubset(repo_files):
                local_dirs_snapshots.append(str(local_path))
    else:
        # an empty repository or unknown repository info; trust that if the directory exists, it matches
        for local_dir in local_dirs:
            local_path = Path(local_dir) / repo_id
            if local_path.is_dir():
                local_dirs_snapshots.append(str(local_path))

    # None scans the default HF_HUB_CACHE location
    for cache_dir in (None, *cache_dirs):
        try:
            cache_dirs_snapshots.append(snapshot_download(repo_id, local_files_only=True, cache_dir=cache_dir))
        except Exception:
            # typically FileNotFoundError: no local snapshot in this cache
            continue
    return cache_dirs_snapshots, local_dirs_snapshots

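A sketch of the return contract, with illustrative directories and repo id:

cache_hits, local_hits = _get_cache_hits(
    ["models/huggingface_cache"], ["models/huggingface"], "org/repo")
# cache_hits: hub-style snapshot paths, e.g. .../models--org--repo/snapshots/<revision>
# local_hits: plain checkouts, e.g. models/huggingface/org/repo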
def _delete_repo_from_huggingface_cache(repo_id: str, cache_dir: Optional[str] = None) -> List[str]:
    results = scan_cache_dir(cache_dir)
    matching = [repo for repo in results.repos if repo.repo_id == repo_id]
    if len(matching) == 0:
        return []
    revisions: List[str] = []
    for repo in matching:
        for revision_info in repo.revisions:
            revisions.append(revision_info.commit_hash)
    results.delete_revisions(*revisions).execute()
    return revisions

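And a sketch of the cache-eviction helper used by the tests below; the repo id is illustrative:

revisions = _delete_repo_from_huggingface_cache("org/repo")  # hypothetical repo id
print(revisions)  # commit hashes of the revisions that were deleted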
@@ -43,7 +43,6 @@ class HuggingFile:
    Attributes:
        repo_id (str): The Huggingface repository of a known file
        filename (str): The path to the known file in the repository
        show_in_ui (bool): Not used. Will indicate whether or not the file should be shown in the UI to reduce clutter
    """
    repo_id: str
    filename: str
@@ -26,7 +26,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..execution_context import current_execution_context
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, huggingface_repos, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
@@ -517,7 +517,7 @@ class DiffusersLoader:
                    if "model_index.json" in files:
                        paths.append(os.path.relpath(root, start=search_path))

        paths += huggingface_repos()
        paths += get_huggingface_repo_list()
        paths = list(frozenset(paths))
        return {"required": {"model_path": (paths,), }}
@@ -9,15 +9,15 @@ from functools import reduce
from typing import Any, Dict, Optional, List, Callable, Union

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
    PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
    LlavaNextForConditionalGeneration, LlavaNextProcessor, T5EncoderModel, AutoModel
    LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
from typing_extensions import TypedDict

from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import huggingface_repos
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
@@ -197,7 +197,7 @@ class TransformersImageProcessorLoader(CustomNode):
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (huggingface_repos(),),
                "ckpt_name": (get_huggingface_repo_list(),),
                "subfolder": ("STRING", {}),
                "model": ("MODEL", {}),
                "overwrite_tokenizer": ("BOOLEAN", {"default": False}),
@@ -212,6 +212,7 @@ class TransformersImageProcessorLoader(CustomNode):
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder
        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
        return model.patch_processor(processor, overwrite_tokenizer),
@@ -221,7 +222,7 @@ class TransformersLoader(CustomNode):
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "ckpt_name": (huggingface_repos(),),
                "ckpt_name": (get_huggingface_repo_list(),),
                "subfolder": ("STRING", {})
            },
        }
@@ -234,6 +235,8 @@ class TransformersLoader(CustomNode):
        hub_kwargs = {}
        if subfolder is not None and subfolder != "":
            hub_kwargs["subfolder"] = subfolder

        ckpt_name = get_or_download_huggingface_repo(ckpt_name)
        with comfy_tqdm():
            from_pretrained_kwargs = {
                "pretrained_model_name_or_path": ckpt_name,
@@ -112,7 +112,7 @@ class FloatRequestParameter(CustomNode):
    def INPUT_TYPES(cls) -> InputTypes:
        return {
            "required": {
                "value": ("FLOAT", {"default": 0})
                "value": ("FLOAT", {"default": 0, "step": 0.00001, "round": 0.00001})
            },
            "optional": {
                **_open_api_common_schema,
@@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor

import jwt
import pytest
from aiohttp import ClientSession, ClientConnectorError
from testcontainers.rabbitmq import RabbitMqContainer

from comfy.client.aio_client import AsyncRemoteComfyClient
@@ -108,3 +109,97 @@ async def test_frontend_backend_workers_validation_error_raises(frontend_backend
    prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
    with pytest.raises(Exception):
        await client.queue_prompt(prompt)


async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
    async with ClientSession() as session:
        for _ in range(max_retries):
            try:
                async with session.get(url, timeout=1) as response:
                    if response.status == 200 and await response.text() == "OK":
                        return True
            except Exception:
                pass
            await asyncio.sleep(retry_delay)
    return False


@pytest.mark.asyncio
async def test_basic_queue_worker_with_health_check():
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        health_check_port = 9090

        async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port):
            # Test health check
            health_check_url = f"http://localhost:{health_check_port}/health"

            health_check_ok = await check_health(health_check_url)
            assert health_check_ok, "Health check server did not start properly"

            # Test the actual worker functionality
            from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
            distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
            await distributed_queue.init()

            queue_item = create_test_prompt()
            res = await distributed_queue.put_async(queue_item)

            assert res.item_id == queue_item.prompt_id
            assert len(res.outputs) == 1
            assert res.status is not None
            assert res.status.status_str == "success"

            await distributed_queue.close()

        # Test that the health check server is stopped after the worker is closed
        health_check_stopped = not await check_health(health_check_url, max_retries=1)
        assert health_check_stopped, "Health check server did not stop properly"


@pytest.mark.asyncio
async def test_health_check_port_conflict():
    with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
        params = rabbitmq.get_connection_params()
        connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
        health_check_port = 9090

        # Start a simple server to occupy the health check port
        from aiohttp import web

        async def dummy_handler(request):
            return web.Response(text="Dummy")

        app = web.Application()
        app.router.add_get('/', dummy_handler)
        runner = web.AppRunner(app)
        await runner.setup()
        site = web.TCPSite(runner, '0.0.0.0', health_check_port)
        await site.start()

        try:
            # Now try to start the DistributedPromptWorker
            async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port):
                # The health check should be disabled, but the worker should still function
                from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
                distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
                await distributed_queue.init()

                queue_item = create_test_prompt()
                res = await distributed_queue.put_async(queue_item)

                assert res.item_id == queue_item.prompt_id
                assert len(res.outputs) == 1
                assert res.status is not None
                assert res.status.status_str == "success"

                await distributed_queue.close()

                # The original server should still be running
                async with ClientSession() as session:
                    async with session.get(f"http://localhost:{health_check_port}") as response:
                        assert response.status == 200
                        assert await response.text() == "Dummy"

        finally:
            await runner.cleanup()
0
tests/downloader/__init__.py
Normal file
152
tests/downloader/test_huggingface_downloads.py
Normal file
@@ -0,0 +1,152 @@
import os
import shutil

import pytest

from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS, get_huggingface_repo_list, \
    get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache

_gitattributes = """*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
"""
def test_known_repos(tmp_path_factory):
    test_cache_dir = tmp_path_factory.mktemp("huggingface_cache")
    test_local_dir = tmp_path_factory.mktemp("huggingface_locals")
    test_repo_id = "doctorpangloss/comfyui_downloader_test"
    prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
    prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
    prev_hub_cache = os.getenv("HF_HUB_CACHE")
    _delete_repo_from_huggingface_cache(test_repo_id)
    _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
    try:
        folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
        folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})

        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"

        # test downloading the repo and observing a cache hit on second access
        existing_repos = get_huggingface_repo_list()
        assert test_repo_id not in existing_repos

        KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
        existing_repos = get_huggingface_repo_list()
        assert test_repo_id in existing_repos

        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"

        # download to cache
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is not None

        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 1, "should have downloaded to cache"
        assert len(locals_hits) == 0, "should not have downloaded to a local dir"

        # load from cache
        args.disable_known_models = True
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is not None, "should have used local path"

        # test deleting from cache
        _delete_repo_from_huggingface_cache(test_repo_id)
        _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 0, "should have deleted from the cache"
        assert len(locals_hits) == 0, "should not have downloaded to a local dir"

        # test that it fails to download
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is None, "should not have downloaded since disable_known_models is True"
        args.disable_known_models = False

        # download to local dir
        args.force_hf_local_dir_mode = True
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is not None
        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 0
        assert len(locals_hits) == 1, "should have downloaded to local dir"

        # test loads from local dir
        args.disable_known_models = True
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is not None

        # test deleting local dir
        expected_path = os.path.join(test_local_dir, test_repo_id)
        shutil.rmtree(expected_path)
        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 0
        assert len(locals_hits) == 0
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is None, "should not download repo into local dir"

        # recreating the test repo should be valid
        os.makedirs(expected_path)
        with open(os.path.join(expected_path, "test.txt"), "wt") as f:
            f.write("OK")
        with open(os.path.join(expected_path, ".gitattributes"), "wt") as f:
            f.write(_gitattributes)

        args.disable_known_models = False
        # expect a local hit
        cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
        assert len(cache_hits) == 0
        assert len(locals_hits) == 1

        # should not download
        path = get_or_download_huggingface_repo(test_repo_id)
        assert path is not None
    finally:
        _delete_repo_from_huggingface_cache(test_repo_id)
        _delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
        KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
        folder_paths.folder_names_and_paths["huggingface"] = prev_huggingface
        folder_paths.folder_names_and_paths["huggingface_cache"] = prev_huggingface_cache
        if prev_hub_cache is None and "HF_HUB_CACHE" in os.environ:
            os.environ.pop("HF_HUB_CACHE")
        elif prev_hub_cache is not None:
            os.environ["HF_HUB_CACHE"] = prev_hub_cache
        args.force_hf_local_dir_mode = False
        args.disable_known_models = False