ComfyUI/comfy/component_model/hf_hub_download_with_disable_xet.py
2025-08-22 17:29:18 -07:00

124 lines
4.8 KiB
Python

from __future__ import annotations
import logging
import os
import platform
import time
from concurrent.futures import Future
from pathlib import Path
from typing import Optional
import filelock
import huggingface_hub
from huggingface_hub import hf_hub_download
from huggingface_hub import logging as hf_logging
# Raise huggingface_hub's own log verbosity to DEBUG as soon as this module is imported.
# NOTE(review): this is an import-time, library-wide side effect that affects every
# huggingface_hub user in the process, not only the helpers below; confirm it is
# intended outside of debugging.
hf_logging.set_verbosity_debug()
from pebble import ThreadPool
from .tqdm_watcher import TqdmWatcher
logger = logging.getLogger(__name__)

# Environment variable that asks huggingface_hub to use the hf_transfer backend.
_VAR = "HF_HUB_ENABLE_HF_TRANSFER"
# Environment variable that asks the xet backend for its high-performance mode.
_XET_VAR = "HF_XET_HIGH_PERFORMANCE"

# Select an accelerated download backend once, at import time, based on the OS.
# NOTE(review): huggingface_hub is imported earlier in this module and may read
# some of these variables only during its own import -- confirm that assigning
# them here still takes effect in-process.
if platform.system() != "Windows":
    os.environ[_XET_VAR] = "True"
else:
    # On Windows xet is switched off (considered unreliable) and hf_transfer is
    # requested instead; hf_transfer presumably requires the optional
    # `hf_transfer` package to be installed -- TODO confirm.
    os.environ.update({"HF_HUB_DISABLE_XET": "1", _VAR: "True"})
    logger.debug("Xet was disabled since it is currently not reliable")
def hf_hub_download_with_disable_fast(repo_id=None, filename=None, disable_fast=None, hf_env: Optional[dict[str, str]] = None, **kwargs):
    """
    Run ``hf_hub_download`` after applying Hugging Face environment overrides.

    ``hf_env`` (when given) is written into ``os.environ`` first. ``os.environ``
    is process-wide, so this affects every thread, not only the one running the
    download. When ``disable_fast`` is truthy, both accelerated backends are
    switched off before downloading: xet (``HF_HUB_DISABLE_XET=1``) and
    hf_transfer (``HF_HUB_ENABLE_HF_TRANSFER=False``).

    :param repo_id: repository id, forwarded to ``hf_hub_download``.
    :param filename: file inside the repository, forwarded to ``hf_hub_download``.
    :param disable_fast: when truthy, turn off the accelerated backends first.
    :param hf_env: environment variables to set before downloading; ``None``
        (the default) sets nothing.
    :param kwargs: forwarded unchanged to ``hf_hub_download``.
    :return: the value returned by ``hf_hub_download``.
    """
    if hf_env:
        os.environ.update(hf_env)
    if disable_fast:
        # Turn off both backends so a retry after a stalled xet download really
        # avoids xet (the branch used to compare two different constants, so
        # only hf_transfer was ever touched).
        os.environ["HF_HUB_DISABLE_XET"] = "1"
        os.environ[_VAR] = "False"
    return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
def _release_captured_locks(locks: list[filelock.FileLock]) -> None:
    """
    Force-release captured file locks, delete their lock files, and empty ``locks``.

    Used between download attempts so locks left by a stalled attempt cannot block
    the next one. This is deliberately best-effort: failures are tolerated.
    """
    if not locks:
        return
    logger.debug("Attempting to unlock %d captured file locks.", len(locks))
    for lock in locks:
        lock_path = lock.lock_file
        if lock.is_locked:
            lock.release(force=True)
        else:
            # The lock object does not report itself as held; still try the
            # low-level release in case an OS-level handle was left behind.
            # `_release` is a private filelock API and may raise when there is
            # nothing to release, so those errors are ignored.
            try:
                lock._release()
            except (AttributeError, TypeError):
                pass
        try:
            Path(lock_path).unlink(missing_ok=True)
        except OSError:
            # TODO: the lock file is presumably still open somewhere in this
            # process (e.g. by the stalled worker); it is left in place.
            pass
        logger.debug("Released stalled lock: %s", lock_path)
    locks.clear()


def hf_hub_download_with_retries(repo_id: str, filename: str, watcher: Optional[TqdmWatcher] = None, retries=2, stall_timeout=10, **kwargs):
    """
    Wraps hf_hub_download with stall detection and retries using a TqdmWatcher.

    Each attempt runs ``hf_hub_download_with_disable_fast`` on a worker thread
    while this thread polls ``watcher.last_update_time`` every 0.5 s. If that value
    is more than ``stall_timeout`` seconds behind ``time.monotonic()``, the attempt
    is abandoned and the next one starts. After a stall, when xet is available,
    later attempts run with ``disable_fast=True``.

    While this function runs, ``filelock.FileLock.__init__`` is monkey-patched to
    record every lock created. Before each attempt, the recorded locks are
    force-released and their lock files deleted.

    :param repo_id: repository id, forwarded to ``hf_hub_download``.
    :param filename: file inside the repository.
    :param watcher: progress watcher. ``tick()`` is called at the start of each
        attempt (presumably resetting ``last_update_time`` -- TODO confirm). When
        ``None``, a single plain ``hf_hub_download`` call is made without stall
        detection.
    :param retries: total number of attempts (not additional retries).
    :param stall_timeout: seconds without progress before an attempt is abandoned.
    :param kwargs: forwarded to ``hf_hub_download``.
    :return: the value returned by the first successful attempt.
    :raises RuntimeError: when every attempt stalled or failed; the exception of
        the last failed attempt, if any, is chained as ``__cause__``.
    """
    if watcher is None:
        logger.warning("called hf_hub_download_with_retries without progress to watch")
        return hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)

    xet_available = huggingface_hub.file_download.is_xet_available()
    # Snapshot the env vars that attempts may modify so they can be restored on exit.
    hf_hub_disable_xet_prev_value = os.getenv("HF_HUB_DISABLE_XET")
    hf_transfer_prev_value = os.getenv(_VAR)
    # NOTE(review): any value of HF_HUB_DISABLE_XET (even "0") counts as "disable"
    # here, and on Windows the module itself sets it at import -- confirm intent.
    disable_fast = hf_hub_disable_xet_prev_value is not None
    last_error: Optional[Exception] = None

    # Capture every FileLock created while this function runs.
    # NOTE(review): the patch is process-wide, so locks created by unrelated
    # threads in this window are captured (and may be force-released) too, and
    # concurrent calls could restore the patch in the wrong order.
    instantiated_locks: list[filelock.FileLock] = []
    original_filelock_init = filelock.FileLock.__init__

    def new_filelock_init(self, *args, **kwargs):
        """A wrapper around FileLock.__init__ to capture lock instances."""
        original_filelock_init(self, *args, **kwargs)
        instantiated_locks.append(self)

    filelock.FileLock.__init__ = new_filelock_init
    try:
        with ThreadPool(max_workers=retries + 1) as executor:
            for attempt in range(retries):
                watcher.tick()
                # Hand the worker a snapshot of the current HF_* settings.
                hf_env = {k: v for k, v in os.environ.items() if k.upper().startswith("HF_")}
                _release_captured_locks(instantiated_locks)
                future: Future[str] = executor.submit(hf_hub_download_with_disable_fast, repo_id=repo_id, filename=filename, disable_fast=disable_fast, hf_env=hf_env, **kwargs)
                try:
                    while not future.done():
                        if time.monotonic() - watcher.last_update_time > stall_timeout:
                            msg = f"Download of '{repo_id}/{filename}' stalled for >{stall_timeout}s. Retrying... (Attempt {attempt + 1}/{retries})"
                            if xet_available:
                                logger.warning(f"{msg}. Disabling xet for our retry.")
                                disable_fast = True
                            else:
                                logger.warning(msg)
                            # NOTE(review): a future whose thread is already running
                            # presumably cannot be cancelled, so the stalled worker may
                            # keep running and leaving the `with` block may wait for it
                            # -- confirm pebble ThreadPool semantics.
                            future.cancel()
                            break
                        time.sleep(0.5)
                    if future.done() and not future.cancelled():
                        return future.result()
                except Exception as e:
                    last_error = e
                    logger.error("Exception during download attempt %d: %s", attempt + 1, e, exc_info=True)
        raise RuntimeError(f"Failed to download '{repo_id}/{filename}' after {retries} attempts.") from last_error
    finally:
        filelock.FileLock.__init__ = original_filelock_init
        # Restore the env vars attempts may have changed; remove ones that were absent.
        for env_name, previous in (("HF_HUB_DISABLE_XET", hf_hub_disable_xet_prev_value), (_VAR, hf_transfer_prev_value)):
            if previous is None:
                os.environ.pop(env_name, None)
            else:
                os.environ[env_name] = previous