mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-01-10 06:10:50 +08:00
80 lines
2.3 KiB
Python
80 lines
2.3 KiB
Python
import threading
|
|
import time
|
|
|
|
import pytest
|
|
|
|
from comfy.component_model.hf_hub_download_with_disable_xet import hf_hub_download_with_retries
|
|
from comfy.component_model.tqdm_watcher import TqdmWatcher
|
|
|
|
download_method_name = "comfy.component_model.hf_hub_download_with_disable_xet.hf_hub_download_with_disable_fast"
|
|
|
|
def mock_stalled_download(*args, **kwargs):
|
|
"""A mock for hf_hub_download that simulates a stall by sleeping indefinitely."""
|
|
time.sleep(10)
|
|
return "this_path_should_never_be_returned"
|
|
|
|
|
|
def test_download_stalls_and_fails(monkeypatch):
|
|
"""
|
|
Verify that a stalled download triggers retries and eventually fails with an RuntimeError.
|
|
"""
|
|
|
|
monkeypatch.setattr(download_method_name, mock_stalled_download)
|
|
watcher = TqdmWatcher()
|
|
repo_id = "test/repo-stall"
|
|
filename = "stalled_file.safetensors"
|
|
|
|
with pytest.raises(RuntimeError) as excinfo:
|
|
hf_hub_download_with_retries(
|
|
repo_id=repo_id,
|
|
filename=filename,
|
|
watcher=watcher,
|
|
stall_timeout=0.2,
|
|
retries=2,
|
|
)
|
|
|
|
assert f"Failed to download '{repo_id}/{filename}' after 2 attempts" in str(excinfo.value)
|
|
|
|
|
|
def mock_successful_slow_download(*args, **kwargs):
|
|
"""A mock for a download that is slow but not stalled."""
|
|
time.sleep(1)
|
|
|
|
return "expected/successful/path"
|
|
|
|
|
|
def _keep_watcher_alive(watcher: TqdmWatcher, stop_event: threading.Event):
|
|
"""Helper function to run in a thread and periodically tick the watcher."""
|
|
while not stop_event.is_set():
|
|
watcher.tick()
|
|
time.sleep(0.1)
|
|
|
|
|
|
def test_download_progresses_and_succeeds(monkeypatch):
|
|
"""
|
|
Verify that a download with periodic progress updates completes successfully.
|
|
"""
|
|
monkeypatch.setattr(download_method_name, mock_successful_slow_download)
|
|
|
|
watcher = TqdmWatcher()
|
|
stop_event = threading.Event()
|
|
ticker_thread = threading.Thread(
|
|
target=_keep_watcher_alive,
|
|
args=(watcher, stop_event),
|
|
daemon=True
|
|
|
|
)
|
|
ticker_thread.start()
|
|
|
|
try:
|
|
result = hf_hub_download_with_retries(
|
|
repo_id="test/repo-success",
|
|
filename="good_file.safetensors",
|
|
stall_timeout=0.3,
|
|
watcher=watcher
|
|
)
|
|
assert result == "expected/successful/path"
|
|
finally:
|
|
stop_event.set()
|
|
ticker_thread.join(timeout=1)
|