Improve model downloading from Hugging Face Hub

doctorpangloss 2024-07-09 12:57:33 -07:00
parent da21da1d8c
commit 3d67224937
13 changed files with 453 additions and 16 deletions

View File

@ -158,6 +158,7 @@ def _create_parser() -> EnhancedConfigArgParser:
parser.add_argument("--otel-service-version", type=str, default=__version__, env_var="OTEL_SERVICE_VERSION", help="The version of the service or application that is generating telemetry data.")
parser.add_argument("--otel-exporter-otlp-endpoint", type=str, default=None, env_var="OTEL_EXPORTER_OTLP_ENDPOINT", help="A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--force-hf-local-dir-mode", action="store_true", help="Download repos from huggingface.co to the models/huggingface directory with the \"local_dir\" argument instead of models/huggingface_cache with the \"cache_dir\" argument, recreating the traditional file structure.")
# now give plugins a chance to add configuration
for entry_point in entry_points().select(group='comfyui.custom_config'):
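For illustration, a sketch of where a repo lands on disk under each mode, assuming the default models directory and huggingface_hub's standard cache layout (the repo id is only an example):

# --force-hf-local-dir-mode ("local_dir"): the repository's own file tree is recreated
#   models/huggingface/stabilityai/stable-diffusion-xl-base-1.0/model_index.json
# default ("cache_dir"): huggingface_hub's cache naming is used
#   models/huggingface_cache/models--stabilityai--stable-diffusion-xl-base-1.0/snapshots/<revision>/model_index.json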

View File

@ -107,6 +107,7 @@ class Configuration(dict):
otel_service_version (str): The version of the service or application that is generating telemetry data. Default: "0.0.1".
otel_exporter_otlp_endpoint (Optional[str]): A base endpoint URL for any signal type, with an optionally-specified port number. Helpful for when you're sending more than one signal to the same endpoint and want one environment variable to control the endpoint.
force_channels_last (bool): Force channels last format when inferencing the models.
force_hf_local_dir_mode (bool): Download repos from huggingface.co into the models/huggingface directory using the "local_dir" argument, preserving each repository's original file layout, instead of into models/huggingface_cache using the "cache_dir" argument.
"""
def __init__(self, **kwargs):
@ -178,6 +179,7 @@ class Configuration(dict):
self.disable_known_models: bool = False
self.max_queue_size: int = 65536
self.force_channels_last: bool = False
self.force_hf_local_dir_mode: bool = False
# from opentracing docs
self.otel_service_name: str = "comfyui"

View File

@ -57,8 +57,6 @@ class FolderNames:
if isinstance(value, tuple):
paths, supported_extensions = value
value = FolderPathsTuple(key, paths, supported_extensions)
self.contents[key] = value
def __len__(self):
@ -67,6 +65,9 @@ class FolderNames:
def __iter__(self):
return iter(self.contents)
def __delitem__(self, key):
del self.contents[key]
def items(self):
return self.contents.items()
@ -110,6 +111,8 @@ folder_names_and_paths["custom_nodes"] = FolderPathsTuple("custom_nodes", [os.pa
folder_names_and_paths["hypernetworks"] = FolderPathsTuple("hypernetworks", [os.path.join(models_dir, "hypernetworks")], set(supported_pt_extensions))
folder_names_and_paths["photomaker"] = FolderPathsTuple("photomaker", [os.path.join(models_dir, "photomaker")], set(supported_pt_extensions))
folder_names_and_paths["classifiers"] = FolderPathsTuple("classifiers", [os.path.join(models_dir, "classifiers")], {""})
folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [os.path.join(models_dir, "huggingface")], {""})
folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [os.path.join(models_dir, "huggingface_cache")], {""})
output_directory = os.path.join(base_path, "output")
temp_directory = os.path.join(base_path, "temp")
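A minimal sketch of reading the two new entries back through the existing accessor:

from comfy.cmd import folder_paths

local_dirs = folder_paths.get_folder_paths("huggingface")  # "local_dir"-style repo trees
cache_dirs = folder_paths.get_folder_paths("huggingface_cache")  # "cache_dir"-style hub caches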

View File

@ -0,0 +1,36 @@
from __future__ import annotations
import warnings
from functools import wraps
from typing import Optional
def _deprecate_method(*, version: str, message: Optional[str] = None):
"""Decorator to issue warnings when using a deprecated method.
Args:
version (`str`):
The version in which calling the deprecated method becomes an error.
message (`str`, *optional*):
Warning message that is raised. If not passed, a default warning message
will be created.
"""
def _inner_deprecate_method(f):
name = f.__name__
if name == "__init__":
name = f.__qualname__.split(".")[0] # class name instead of method name
@wraps(f)
def inner_f(*args, **kwargs):
warning_message = (
f"'{name}' (from '{f.__module__}') is deprecated and will be removed in version '{version}'."
)
if message is not None:
warning_message += " " + message
warnings.warn(warning_message, FutureWarning)
return f(*args, **kwargs)
return inner_f
return _inner_deprecate_method
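A minimal usage sketch with a hypothetical method:

class Example:
    @_deprecate_method(version="1.0.0", message="use new_method instead.")
    def old_method(self):
        ...

# Example().old_method() emits:
# FutureWarning: 'old_method' (from '...') is deprecated and will be removed in version '1.0.0'. use new_method instead.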

View File

@ -7,6 +7,7 @@ from typing import Optional
import errno
from aio_pika import connect_robust
from aio_pika.patterns import JsonRPC
from aiohttp import web
from aiormq import AMQPConnectionError
from .distributed_progress import DistributedExecutorToClientProgress
@ -24,6 +25,7 @@ class DistributedPromptWorker:
def __init__(self, embedded_comfy_client: Optional[EmbeddedComfyClient] = None,
connection_uri: str = "amqp://localhost:5672/",
queue_name: str = "comfyui",
health_check_port: int = 9090,
loop: Optional[AbstractEventLoop] = None):
self._rpc = None
self._channel = None
@ -32,6 +34,29 @@ class DistributedPromptWorker:
self._connection_uri = connection_uri
self._loop = loop or asyncio.get_event_loop()
self._embedded_comfy_client = embedded_comfy_client
self._health_check_port = health_check_port
self._health_check_site = None
async def _health_check(self, request):
return web.Response(text="OK", content_type="text/plain")
async def _start_health_check_server(self):
app = web.Application()
app.router.add_get('/health', self._health_check)
runner = web.AppRunner(app)
await runner.setup()
try:
site = web.TCPSite(runner, port=self._health_check_port)
await site.start()
self._health_check_site = site
logging.info(f"health check server started on port {self._health_check_port}")
except OSError as e:
if e.errno == errno.EADDRINUSE:
logging.warning(f"port {self._health_check_port} is already in use; continuing without a health check")
else:
logging.error(f"failed to start health check server: {e}; continuing without it")
@tracer.start_as_current_span("Do Work Item")
async def _do_work_item(self, request: dict) -> dict:
@ -74,6 +99,7 @@ class DistributedPromptWorker:
await self._exit_stack.enter_async_context(self._embedded_comfy_client)
await self._rpc.register(self._queue_name, self._do_work_item)
await self._start_health_check_server()
async def __aenter__(self) -> "DistributedPromptWorker":
await self.init()
@ -83,6 +109,8 @@ class DistributedPromptWorker:
await self._rpc.close()
await self._channel.close()
await self._connection.close()
if self._health_check_site:
await self._health_check_site.stop()
async def close(self):
await self._close()
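A minimal sketch of probing the endpoint from outside the worker, assuming the default port; the route returns a plain-text 200 "OK" while the worker is up:

import asyncio

from aiohttp import ClientSession

async def probe() -> bool:
    async with ClientSession() as session:
        async with session.get("http://localhost:9090/health") as response:
            return response.status == 200 and await response.text() == "OK"

print(asyncio.run(probe()))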

View File

@ -1,13 +1,16 @@
from __future__ import annotations
import logging
import operator
import os
from functools import reduce
from itertools import chain
from os.path import join
from typing import List, Any, Optional, Union
from pathlib import Path
from typing import List, Any, Optional, Union, Sequence
import tqdm
from huggingface_hub import hf_hub_download, scan_cache_dir
from huggingface_hub import hf_hub_download, scan_cache_dir, snapshot_download, HfFileSystem
from huggingface_hub.utils import GatedRepoError
from requests import Session
from safetensors import safe_open
@ -15,11 +18,13 @@ from safetensors.torch import save_file
from .cli_args import args
from .cmd import folder_paths
from .component_model.deprecation import _deprecate_method
from .interruption import InterruptProcessingException
from .model_downloader_types import CivitFile, HuggingFile, CivitModelsGetResponse, CivitFile_
from .utils import ProgressBar, comfy_tqdm
_session = Session()
_hf_fs = HfFileSystem()
def get_filename_list_with_downloadable(folder_name: str, known_files: List[Any]) -> List[str]:
@ -365,8 +370,121 @@ def add_known_models(folder_name: str, known_models: List[Union[CivitFile, Huggi
return known_models
@_deprecate_method(version="1.0.0", message="use get_huggingface_repo_list instead")
def huggingface_repos() -> List[str]:
return get_huggingface_repo_list()
def get_huggingface_repo_list(*extra_cache_dirs: str) -> List[str]:
if len(extra_cache_dirs) == 0:
extra_cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
# union of repos across the default cache and any extra cache directories
cache_infos = [scan_cache_dir()] + [scan_cache_dir(cache_dir=cache_dir) for cache_dir in extra_cache_dirs if os.path.isdir(cache_dir)]
existing_repo_ids = frozenset(
cache_item.repo_id
for cache_item in reduce(operator.or_, (cache_info.repos for cache_info in cache_infos))
if cache_item.repo_type == "model"
)
# also check local-dir style directories
existing_local_dir_repos = set()
local_dirs = folder_paths.get_folder_paths("huggingface")
for local_dir_root in local_dirs:
# enumerate all the two-directory paths
if not os.path.isdir(local_dir_root):
continue
for user_dir in Path(local_dir_root).iterdir():
if not user_dir.is_dir():
continue
for model_dir in user_dir.iterdir():
if not model_dir.is_dir():
continue
repo_id = f"{user_dir.name}/{model_dir.name}"
try:
_hf_fs.resolve_path(repo_id)
except Exception as exc_info:
logging.debug(f"HfFileSystem did not recognize {repo_id} as a valid repo: {exc_info}", exc_info=exc_info)
continue
existing_local_dir_repos.add(repo_id)
known_repo_ids = frozenset(KNOWN_HUGGINGFACE_MODEL_REPOS)
if args.disable_known_models:
return list(existing_repo_ids | existing_local_dir_repos)
else:
return list(existing_repo_ids | existing_local_dir_repos | known_repo_ids)
def get_or_download_huggingface_repo(repo_id: str) -> Optional[str]:
cache_dirs = folder_paths.get_folder_paths("huggingface_cache")
local_dirs = folder_paths.get_folder_paths("huggingface")
cache_dirs_snapshots, local_dirs_snapshots = _get_cache_hits(cache_dirs, local_dirs, repo_id)
local_dirs_cache_hit = len(local_dirs_snapshots) > 0
cache_dirs_cache_hit = len(cache_dirs_snapshots) > 0
logging.debug(f"cache {'hit' if local_dirs_cache_hit or cache_dirs_cache_hit else 'miss'} for repo_id={repo_id} because local_dirs={local_dirs_cache_hit}, cache_dirs={cache_dirs_cache_hit}")
# if we're in forced local directory mode, only use the local dir snapshots, and otherwise, download
if args.force_hf_local_dir_mode:
# todo: we still have to figure out a way to download things to the right places by default
if len(local_dirs_snapshots) > 0:
return local_dirs_snapshots[0]
elif not args.disable_known_models:
destination = os.path.join(local_dirs[0], repo_id)
logging.debug(f"downloading repo_id={repo_id}, local_dir={destination}")
return snapshot_download(repo_id, local_dir=destination)
snapshots = local_dirs_snapshots + cache_dirs_snapshots
if len(snapshots) > 0:
return snapshots[0]
elif not args.disable_known_models:
logging.debug(f"downloading repo_id={repo_id}")
return snapshot_download(repo_id)
# this repo was not found
return None
def _get_cache_hits(cache_dirs: Sequence[str], local_dirs: Sequence[str], repo_id):
local_dirs_snapshots = []
cache_dirs_snapshots = []
# find all the pre-existing downloads for this repo_id
try:
repo_files = set(_hf_fs.ls(repo_id, detail=False))
except Exception:
repo_files = set()
if len(repo_files) > 0:
for local_dir in local_dirs:
local_path = Path(local_dir) / repo_id
local_files = set(f"{repo_id}/{f.relative_to(local_path)}" for f in local_path.rglob("*") if f.is_file())
# fix path representation
local_files = set(f.replace("\\", "/") for f in local_files)
# remove .huggingface
local_files = set(f for f in local_files if not f.startswith(f"{repo_id}/.huggingface"))
# a local tree counts as a snapshot only if its files are a subset of the repo's files
if local_files.issubset(repo_files):
local_dirs_snapshots.append(str(local_path))
else:
# an empty repository or unknown repository info, trust that if the directory exists, it matches
for local_dir in local_dirs:
local_path = Path(local_dir) / repo_id
if local_path.is_dir():
local_dirs_snapshots.append(str(local_path))
for cache_dir in (None, *cache_dirs):
try:
cache_dirs_snapshots.append(snapshot_download(repo_id, local_files_only=True, cache_dir=cache_dir))
except Exception:
# not present in this cache (FileNotFoundError) or the scan failed; try the next one
continue
return cache_dirs_snapshots, local_dirs_snapshots
def _delete_repo_from_huggingface_cache(repo_id: str, cache_dir: Optional[str] = None) -> List[str]:
results = scan_cache_dir(cache_dir)
matching = [repo for repo in results.repos if repo.repo_id == repo_id]
if len(matching) == 0:
return []
revisions: List[str] = []
for repo in matching:
for revision_info in repo.revisions:
revisions.append(revision_info.commit_hash)
results.delete_revisions(*revisions).execute()
return revisions
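Taken together, callers resolve a repo id to a local filesystem path in one call; a minimal usage sketch (the repo id is only an example):

from comfy.model_downloader import get_or_download_huggingface_repo

path = get_or_download_huggingface_repo("stabilityai/stable-diffusion-xl-base-1.0")
if path is not None:
    # path is either a "local_dir"-style tree or a cache snapshot directory,
    # suitable for e.g. AutoModel.from_pretrained(path)
    print(path)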

View File

@ -43,7 +43,6 @@ class HuggingFile:
Attributes:
repo_id (str): The Huggingface repository of a known file
filename (str): The path to the known file in the repository
"""
repo_id: str
filename: str

View File

@ -26,7 +26,7 @@ from ..cli_args import args
from ..cmd import folder_paths, latent_preview
from ..execution_context import current_execution_context
from ..images import open_image
from ..model_downloader import get_filename_list_with_downloadable, get_or_download, KNOWN_CHECKPOINTS, KNOWN_CLIP_VISION_MODELS, KNOWN_GLIGEN_MODELS, KNOWN_UNCLIP_CHECKPOINTS, KNOWN_LORAS, KNOWN_CONTROLNETS, KNOWN_DIFF_CONTROLNETS, KNOWN_VAES, KNOWN_APPROX_VAES, get_huggingface_repo_list, KNOWN_CLIP_MODELS, KNOWN_UNET_MODELS
from ..nodes.common import MAX_RESOLUTION
from .. import controlnet
from ..open_exr import load_exr
@ -517,7 +517,7 @@ class DiffusersLoader:
if "model_index.json" in files:
paths.append(os.path.relpath(root, start=search_path))
paths += get_huggingface_repo_list()
paths = list(frozenset(paths))
return {"required": {"model_path": (paths,), }}

View File

@ -9,15 +9,15 @@ from functools import reduce
from typing import Any, Dict, Optional, List, Callable, Union
import torch
from transformers import AutoTokenizer, PreTrainedModel, LogitsProcessor, TextStreamer, \
PreTrainedTokenizerBase, LogitsProcessorList, PretrainedConfig, AutoProcessor, BatchFeature, ProcessorMixin, \
LlavaNextForConditionalGeneration, LlavaNextProcessor, AutoModel
from typing_extensions import TypedDict
from comfy.language.chat_templates import KNOWN_CHAT_TEMPLATES
from comfy.language.language_types import ProcessorResult
from comfy.language.transformers_model_management import TransformersManagedModel
from comfy.model_downloader import get_huggingface_repo_list, get_or_download_huggingface_repo
from comfy.model_management import get_torch_device_name, load_model_gpu, unet_dtype, unet_offload_device
from comfy.nodes.package_typing import CustomNode, InputTypes, ValidatedNodeResult
from comfy.utils import comfy_tqdm, seed_for_block, comfy_progress, ProgressBar
@ -197,7 +197,7 @@ class TransformersImageProcessorLoader(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),),
"ckpt_name": (get_huggingface_repo_list(),),
"subfolder": ("STRING", {}),
"model": ("MODEL", {}),
"overwrite_tokenizer": ("BOOLEAN", {"default": False}),
@ -212,6 +212,7 @@ class TransformersImageProcessorLoader(CustomNode):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
processor = AutoProcessor.from_pretrained(ckpt_name, torch_dtype=unet_dtype(), device_map=get_torch_device_name(unet_offload_device()), low_cpu_mem_usage=True, trust_remote_code=True, **hub_kwargs)
return model.patch_processor(processor, overwrite_tokenizer),
@ -221,7 +222,7 @@ class TransformersLoader(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"ckpt_name": (huggingface_repos(),),
"ckpt_name": (get_huggingface_repo_list(),),
"subfolder": ("STRING", {})
},
}
@ -234,6 +235,8 @@ class TransformersLoader(CustomNode):
hub_kwargs = {}
if subfolder is not None and subfolder != "":
hub_kwargs["subfolder"] = subfolder
ckpt_name = get_or_download_huggingface_repo(ckpt_name)
with comfy_tqdm():
from_pretrained_kwargs = {
"pretrained_model_name_or_path": ckpt_name,

View File

@ -112,7 +112,7 @@ class FloatRequestParameter(CustomNode):
def INPUT_TYPES(cls) -> InputTypes:
return {
"required": {
"value": ("FLOAT", {"default": 0})
"value": ("FLOAT", {"default": 0, "step": 0.00001, "round": 0.00001})
},
"optional": {
**_open_api_common_schema,

View File

@ -4,6 +4,7 @@ from concurrent.futures import ThreadPoolExecutor
import jwt
import pytest
from aiohttp import ClientSession, ClientConnectorError
from testcontainers.rabbitmq import RabbitMqContainer
from comfy.client.aio_client import AsyncRemoteComfyClient
@ -108,3 +109,97 @@ async def test_frontend_backend_workers_validation_error_raises(frontend_backend
prompt = sdxl_workflow_with_refiner("test", inference_steps=1, refiner_steps=1, sdxl_refiner_checkpoint_name="unknown.safetensors")
with pytest.raises(Exception):
await client.queue_prompt(prompt)
async def check_health(url: str, max_retries: int = 5, retry_delay: float = 1.0):
async with ClientSession() as session:
for _ in range(max_retries):
try:
async with session.get(url, timeout=1) as response:
if response.status == 200 and await response.text() == "OK":
return True
except Exception:
pass
await asyncio.sleep(retry_delay)
return False
@pytest.mark.asyncio
async def test_basic_queue_worker_with_health_check():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
health_check_port = 9090
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
# Test health check
health_check_url = f"http://localhost:{health_check_port}/health"
health_check_ok = await check_health(health_check_url)
assert health_check_ok, "Health check server did not start properly"
# Test the actual worker functionality
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
await distributed_queue.init()
queue_item = create_test_prompt()
res = await distributed_queue.put_async(queue_item)
assert res.item_id == queue_item.prompt_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
await distributed_queue.close()
# Test that the health check server is stopped after the worker is closed
health_check_stopped = not await check_health(health_check_url, max_retries=1)
assert health_check_stopped, "Health check server did not stop properly"
@pytest.mark.asyncio
async def test_health_check_port_conflict():
with RabbitMqContainer("rabbitmq:latest") as rabbitmq:
params = rabbitmq.get_connection_params()
connection_uri = f"amqp://guest:guest@127.0.0.1:{params.port}"
health_check_port = 9090
# Start a simple server to occupy the health check port
from aiohttp import web
async def dummy_handler(request):
return web.Response(text="Dummy")
app = web.Application()
app.router.add_get('/', dummy_handler)
runner = web.AppRunner(app)
await runner.setup()
site = web.TCPSite(runner, '0.0.0.0', health_check_port)
await site.start()
try:
# Now try to start the DistributedPromptWorker
async with DistributedPromptWorker(connection_uri=connection_uri, health_check_port=health_check_port) as worker:
# The health check should be disabled, but the worker should still function
from comfy.distributed.distributed_prompt_queue import DistributedPromptQueue
distributed_queue = DistributedPromptQueue(ServerStub(), is_callee=False, is_caller=True, connection_uri=connection_uri)
await distributed_queue.init()
queue_item = create_test_prompt()
res = await distributed_queue.put_async(queue_item)
assert res.item_id == queue_item.prompt_id
assert len(res.outputs) == 1
assert res.status is not None
assert res.status.status_str == "success"
await distributed_queue.close()
# The original server should still be running
async with ClientSession() as session:
async with session.get(f"http://localhost:{health_check_port}") as response:
assert response.status == 200
assert await response.text() == "Dummy"
finally:
await runner.cleanup()

View File

View File

@ -0,0 +1,152 @@
import os
import shutil
import pytest
from comfy.cli_args import args
from comfy.cmd import folder_paths
from comfy.cmd.folder_paths import FolderPathsTuple
from comfy.model_downloader import KNOWN_HUGGINGFACE_MODEL_REPOS, get_huggingface_repo_list, \
get_or_download_huggingface_repo, _get_cache_hits, _delete_repo_from_huggingface_cache
_gitattributes = """*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tar filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
"""
def test_known_repos(tmp_path_factory):
test_cache_dir = tmp_path_factory.mktemp("huggingface_cache")
test_local_dir = tmp_path_factory.mktemp("huggingface_locals")
test_repo_id = "doctorpangloss/comfyui_downloader_test"
prev_huggingface = folder_paths.folder_names_and_paths["huggingface"]
prev_huggingface_cache = folder_paths.folder_names_and_paths["huggingface_cache"]
prev_hub_cache = os.getenv("HF_HUB_CACHE")
_delete_repo_from_huggingface_cache(test_repo_id)
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
try:
folder_paths.folder_names_and_paths["huggingface"] = FolderPathsTuple("huggingface", [test_local_dir], {""})
folder_paths.folder_names_and_paths["huggingface_cache"] = FolderPathsTuple("huggingface_cache", [test_cache_dir], {""})
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
# test downloading the repo and observing a cache hit on second access
existing_repos = get_huggingface_repo_list()
assert test_repo_id not in existing_repos
KNOWN_HUGGINGFACE_MODEL_REPOS.add(test_repo_id)
existing_repos = get_huggingface_repo_list()
assert test_repo_id in existing_repos
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == len(locals_hits) == 0, "not downloaded yet"
# download to cache
path = get_or_download_huggingface_repo(test_repo_id)
assert path is not None
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == 1, "should have downloaded to cache"
assert len(locals_hits) == 0, "should not have downloaded to a local dir"
# load from cache
args.disable_known_models = True
path = get_or_download_huggingface_repo(test_repo_id)
assert path is not None, "should have used local path"
# test deleting from cache
_delete_repo_from_huggingface_cache(test_repo_id)
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == 0, "should have deleted from the cache"
assert len(locals_hits) == 0, "should not have downloaded to a local dir"
# test fails to download
path = get_or_download_huggingface_repo(test_repo_id)
assert path is None, "should not have downloaded since disable_known_models is True"
args.disable_known_models = False
# download to local dir
args.force_hf_local_dir_mode = True
path = get_or_download_huggingface_repo(test_repo_id)
assert path is not None
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == 0
assert len(locals_hits) == 1, "should have downloaded to local dir"
# test loads from local dir
args.disable_known_models = True
path = get_or_download_huggingface_repo(test_repo_id)
assert path is not None
# test deleting local dir
expected_path = os.path.join(test_local_dir, test_repo_id)
shutil.rmtree(expected_path)
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == 0
assert len(locals_hits) == 0
path = get_or_download_huggingface_repo(test_repo_id)
assert path is None, "should not download repo into local dir"
# recreating the test repo should be valid
os.makedirs(expected_path)
with open(os.path.join(expected_path, "test.txt"), "wt") as f:
f.write("OK")
with open(os.path.join(expected_path, ".gitattributes"), "wt") as f:
f.write(_gitattributes)
args.disable_known_models = False
# expect local hit
cache_hits, locals_hits = _get_cache_hits([test_cache_dir], [test_local_dir], test_repo_id)
assert len(cache_hits) == 0
assert len(locals_hits) == 1
# should not download
path = get_or_download_huggingface_repo(test_repo_id)
assert path is not None
finally:
_delete_repo_from_huggingface_cache(test_repo_id)
_delete_repo_from_huggingface_cache(test_repo_id, test_cache_dir)
KNOWN_HUGGINGFACE_MODEL_REPOS.remove(test_repo_id)
folder_paths.folder_names_and_paths["huggingface"] = prev_huggingface
folder_paths.folder_names_and_paths["huggingface_cache"] = prev_huggingface_cache
if prev_hub_cache is None and "HF_HUB_CACHE" in os.environ:
os.environ.pop("HF_HUB_CACHE")
elif prev_hub_cache is not None:
os.environ["HF_HUB_CACHE"] = prev_hub_cache
args.force_hf_local_dir_mode = False
args.disable_known_models = False