mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-11 06:40:48 +08:00
124 lines
4.8 KiB
Python
124 lines
4.8 KiB
Python
from __future__ import annotations
|
|
|
|
import logging
|
|
import os
|
|
import platform
|
|
import time
|
|
from concurrent.futures import Future
|
|
from pathlib import Path
|
|
from typing import Optional
|
|
|
|
import filelock
|
|
import huggingface_hub
|
|
from huggingface_hub import hf_hub_download
|
|
from huggingface_hub import logging as hf_logging
|
|
|
|
hf_logging.set_verbosity_debug()
|
|
from pebble import ThreadPool
|
|
|
|
from .tqdm_watcher import TqdmWatcher
|
|
|
|
# Logger named after this module.
logger = logging.getLogger(__name__)

# Names of the two Hugging Face environment variables this module toggles.
_VAR = "HF_HUB_ENABLE_HF_TRANSFER"
_XET_VAR = "HF_XET_HIGH_PERFORMANCE"

# Import-time environment setup. Every platform other than Windows only sets
# HF_XET_HIGH_PERFORMANCE; Windows sets HF_HUB_DISABLE_XET and
# HF_HUB_ENABLE_HF_TRANSFER instead.
if platform.system() != "Windows":
    os.environ[_XET_VAR] = "True"
else:
    os.environ["HF_HUB_DISABLE_XET"] = "1"
    logger.debug("Xet was disabled since it is currently not reliable")
    os.environ[_VAR] = "True"
|
|
|
|
|
|
|
|
def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: Optional[dict[str, str]] = None, **kwargs):
    """
    Apply environment overrides, optionally switch off the fast download path,
    then delegate to hf_hub_download.

    Args:
        repo_id: Forwarded unchanged to hf_hub_download.
        filename: Forwarded unchanged to hf_hub_download.
        disable_fast: When truthy, the fast-path environment variable is
            overridden before downloading (see the branch below).
        hf_env: Environment variables to copy into os.environ before the
            download. None (the default) or an empty mapping means no
            overrides.
        **kwargs: Forwarded unchanged to hf_hub_download.

    Returns:
        Whatever hf_hub_download returns.
    """
    # hf_env defaults to None, so guard before using it; calling this function
    # without hf_env must not fail.
    if hf_env:
        os.environ.update(hf_env)

    if disable_fast:
        # With the module's current constants _VAR and _XET_VAR differ, so the
        # else branch is the one that runs; the first branch only applies if
        # both names are ever pointed at the same variable.
        if _VAR == _XET_VAR:
            os.environ["HF_HUB_DISABLE_XET"] = "1"
        else:
            os.environ[_VAR] = "False"

    return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
|
|
|
|
|
|
def _release_captured_locks(locks: list[filelock.FileLock]) -> None:
    """Best-effort release of every captured lock, then empty the list in place."""
    if not locks:
        return

    logger.debug(f"Attempting to unlock {len(locks)} captured file locks.")
    for lock in locks:
        path = lock.lock_file
        if lock.is_locked:
            lock.release(force=True)
        else:
            # something else went wrong
            try:
                lock._release()
            except (AttributeError, TypeError):
                # Deliberate best-effort: the private hook may be missing or
                # unusable for this lock; fall through to removing the file.
                pass
            try:
                Path(path).unlink(missing_ok=True)
            except OSError:
                # todo: obviously the process is holding this lock
                pass
        logger.debug(f"Released stalled lock: {lock.lock_file}")
    locks.clear()


def hf_hub_download_with_retries(repo_id: str, filename: str, watcher: Optional[TqdmWatcher] = None, retries=2, stall_timeout=10, **kwargs):
    """
    Wraps hf_hub_download with stall detection and retries using a TqdmWatcher.
    Includes a monkey-patch for filelock to release locks from stalled downloads.

    Args:
        repo_id: Repository identifier forwarded to the download call.
        filename: File name forwarded to the download call.
        watcher: Progress watcher whose last_update_time is compared against
            time.monotonic() to detect a stall. When None, no stall detection
            is possible and hf_hub_download is called directly.
        retries: Number of download attempts made before giving up.
        stall_timeout: Seconds without a watcher update after which the
            current attempt is treated as stalled.
        **kwargs: Forwarded to the download call.

    Returns:
        The result of the first attempt that completes without being cancelled.

    Raises:
        RuntimeError: If no attempt completes. When an attempt raised, that
            exception is attached as the cause.
    """
    if watcher is None:
        logger.warning("called _hf_hub_download_with_retries without progress to watch")
        return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)

    xet_available = huggingface_hub.file_download.is_xet_available()
    hf_hub_disable_xet_prev_value = os.getenv("HF_HUB_DISABLE_XET")
    # Start with the fast path disabled if HF_HUB_DISABLE_XET is set at all.
    disable_fast = hf_hub_disable_xet_prev_value is not None

    # Every FileLock constructed while the patch is active is recorded here so
    # that locks belonging to a stalled attempt can be released before a retry.
    instantiated_locks: list[filelock.FileLock] = []
    original_filelock_init = filelock.FileLock.__init__

    def new_filelock_init(self, *init_args, **init_kwargs):
        """A wrapper around FileLock.__init__ to capture lock instances."""
        original_filelock_init(self, *init_args, **init_kwargs)
        instantiated_locks.append(self)

    filelock.FileLock.__init__ = new_filelock_init

    # Most recent exception raised by an attempt; chained onto the final error
    # so the root cause is not lost.
    last_error: Optional[Exception] = None

    try:
        with ThreadPool(max_workers=retries + 1) as executor:
            for attempt in range(retries):
                # presumably tick() refreshes watcher.last_update_time so the
                # new attempt is not immediately seen as stalled -- TODO confirm
                watcher.tick()
                # Snapshot the HF_* variables; the worker copies them back into
                # os.environ before downloading.
                hf_env = {k: v for k, v in os.environ.items() if k.upper().startswith("HF_")}

                _release_captured_locks(instantiated_locks)

                future: Future[str] = executor.submit(hf_hub_download_with_disable_fast, repo_id=repo_id, filename=filename, disable_fast=disable_fast, hf_env=hf_env, **kwargs)

                try:
                    while not future.done():
                        if time.monotonic() - watcher.last_update_time > stall_timeout:
                            msg = f"Download of '{repo_id}/{filename}' stalled for >{stall_timeout}s. Retrying... (Attempt {attempt + 1}/{retries})"
                            if xet_available:
                                logger.warning(f"{msg}. Disabling xet for our retry.")
                                disable_fast = True
                            else:
                                logger.warning(msg)

                            # NOTE(review): cancel() presumably cannot stop a
                            # task that is already running in a worker thread
                            # -- confirm against the pool's documentation.
                            future.cancel()  # Cancel the stalled future
                            break

                        time.sleep(0.5)

                    if future.done() and not future.cancelled():
                        return future.result()

                except Exception as e:
                    last_error = e
                    logger.error(f"Exception during download attempt {attempt + 1}: {e}", exc_info=True)

        raise RuntimeError(f"Failed to download '{repo_id}/{filename}' after {retries} attempts.") from last_error
    finally:
        # Always undo the monkey-patch, whether we return or raise.
        filelock.FileLock.__init__ = original_filelock_init

        # Only a previously set value is restored; if the variable was unset
        # on entry it is left as it is now.
        if hf_hub_disable_xet_prev_value is not None:
            os.environ["HF_HUB_DISABLE_XET"] = hf_hub_disable_xet_prev_value
|