ComfyUI/tests/unit/test_download_bailout.py
2025-08-07 17:29:23 -07:00

80 lines
2.3 KiB
Python

import threading
import time
import pytest
from comfy.component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
from comfy.component_model.tqdm_watcher import TqdmWatcher
# Dotted path handed to ``monkeypatch.setattr`` by the tests in this file, so
# that the function living at this location is swapped for a mock download.
# NOTE(review): presumably this is the call that hf_hub_download_with_retries
# makes internally — confirm against the module under test.
download_method_name = "comfy.component_model.hf_hub_download_with_disable_xet.hf_hub_download_with_disable_fast"
def mock_stalled_download(*args, **kwargs):
    """Stand-in for the real download call that blocks for 10 seconds.

    Every positional and keyword argument is accepted and ignored. The
    sleep is much longer than the 0.2 second ``stall_timeout`` used by the
    stall test, so the sentinel path below is not expected to reach the
    caller.
    """
    stall_seconds = 10
    time.sleep(stall_seconds)
    return "this_path_should_never_be_returned"
def test_download_stalls_and_fails(monkeypatch):
    """
    Verify that a download which never makes progress ends in a RuntimeError.

    The download function is replaced with ``mock_stalled_download`` and the
    retrying wrapper is called with a 0.2 second stall timeout and two
    retries. The test passes only when a RuntimeError is raised whose
    message names the repo, the file and the attempt count.
    """
    repo_id = "test/repo-stall"
    filename = "stalled_file.safetensors"
    monkeypatch.setattr(download_method_name, mock_stalled_download)

    download_kwargs = dict(
        repo_id=repo_id,
        filename=filename,
        watcher=TqdmWatcher(),
        stall_timeout=0.2,
        retries=2,
    )
    with pytest.raises(RuntimeError) as raised:
        hf_hub_download_with_retries(**download_kwargs)

    # Checked after the ``with`` block: a statement placed inside it, behind
    # the raising call, would never execute.
    expected_fragment = f"Failed to download '{repo_id}/{filename}' after 2 attempts"
    assert expected_fragment in str(raised.value)
def mock_successful_slow_download(*args, **kwargs):
    """Stand-in for a download that takes one second and then succeeds.

    Every positional and keyword argument is accepted and ignored. The
    fixed path string returned here is the value the success test compares
    its result against.
    """
    delay_seconds = 1
    time.sleep(delay_seconds)
    return "expected/successful/path"
def _keep_watcher_alive(watcher: TqdmWatcher, stop_event: threading.Event):
"""Helper function to run in a thread and periodically tick the watcher."""
while not stop_event.is_set():
watcher.tick()
time.sleep(0.1)
def test_download_progresses_and_succeeds(monkeypatch):
    """
    Verify that a slow download finishes when its watcher keeps being ticked.

    The download function is replaced with ``mock_successful_slow_download``
    (a one second sleep) while a background thread ticks the watcher every
    0.1 seconds. The wrapper is called with a 0.3 second stall timeout and
    must return the mock's path.
    """
    monkeypatch.setattr(download_method_name, mock_successful_slow_download)

    progress_watcher = TqdmWatcher()
    halt_ticker = threading.Event()
    ticker = threading.Thread(
        target=_keep_watcher_alive,
        args=(progress_watcher, halt_ticker),
        daemon=True,
    )
    ticker.start()

    try:
        downloaded_path = hf_hub_download_with_retries(
            repo_id="test/repo-success",
            filename="good_file.safetensors",
            stall_timeout=0.3,
            watcher=progress_watcher,
        )
    finally:
        # Always stop the ticker, whether or not the download call raised.
        halt_ticker.set()
        ticker.join(timeout=1)

    assert downloaded_path == "expected/successful/path"