mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-24 00:39:30 +08:00
test(model_downloader): unit tests for downloads, allowlist, paths, and HF auth
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
parent
82ccce7d09
commit
dcd61f1132
567
tests-unit/app_test/hf_auth_test.py
Normal file
567
tests-unit/app_test/hf_auth_test.py
Normal file
@ -0,0 +1,567 @@
|
|||||||
|
"""Unit tests for the HuggingFace auth subsystem.
|
||||||
|
|
||||||
|
Covers:
|
||||||
|
- token store: save/load roundtrip, chmod 0600, atomic write, delete
|
||||||
|
- eligibility under various CLI-arg combinations
|
||||||
|
- URL parsing (huggingface.co host detection + repo_id extraction)
|
||||||
|
- HF-aware gated_detection.probe_url (mocked auth_check)
|
||||||
|
- HF auth routes (token status, login start with eligibility gate, logout)
|
||||||
|
- PKCE primitives + authorize URL shape
|
||||||
|
|
||||||
|
The OAuth callback server itself isn't exercised end-to-end here — that
|
||||||
|
requires a real HF server. We test the components (state checking,
|
||||||
|
URL building, code-exchange request shape) instead.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import stat
|
||||||
|
from unittest.mock import AsyncMock, MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from app.model_downloader.api.routes import register_routes
|
||||||
|
from app.model_downloader.hf_auth import oauth
|
||||||
|
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore
|
||||||
|
from app.model_downloader.hf_auth.token_store import (
|
||||||
|
EXPIRY_BUFFER_SECS,
|
||||||
|
Token,
|
||||||
|
delete_token,
|
||||||
|
load_token,
|
||||||
|
save_token,
|
||||||
|
)
|
||||||
|
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
|
||||||
|
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Fixtures
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patched_user_dir(tmp_path):
|
||||||
|
"""Redirect ``folder_paths.get_user_directory`` so the token file
|
||||||
|
lands in an isolated tmp_path instead of the real user dir."""
|
||||||
|
user_dir = tmp_path / "user"
|
||||||
|
user_dir.mkdir()
|
||||||
|
with patch("folder_paths.get_user_directory", return_value=str(user_dir)):
|
||||||
|
yield user_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_auth_store():
|
||||||
|
"""Wipe singleton state between tests: auth + probe caches."""
|
||||||
|
from app.model_downloader import gated_detection
|
||||||
|
|
||||||
|
HF_AUTH_STORE._token = None
|
||||||
|
HF_AUTH_STORE._loaded_from_disk = False
|
||||||
|
gated_detection.clear_caches_for_tests()
|
||||||
|
yield HF_AUTH_STORE
|
||||||
|
HF_AUTH_STORE._token = None
|
||||||
|
HF_AUTH_STORE._loaded_from_disk = False
|
||||||
|
gated_detection.clear_caches_for_tests()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(patched_user_dir, fresh_auth_store):
|
||||||
|
app = web.Application()
|
||||||
|
register_routes(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# URL parsing
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_hf_url_recognises_huggingface_co():
|
||||||
|
assert is_hf_url("https://huggingface.co/x/y/resolve/main/z.safetensors")
|
||||||
|
assert is_hf_url("https://huggingface.co/abc")
|
||||||
|
assert not is_hf_url("https://hf-mirror.com/x/y/resolve/main/z.safetensors")
|
||||||
|
assert not is_hf_url("https://civitai.com/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_extracts_org_and_repo():
|
||||||
|
url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors"
|
||||||
|
assert repo_id_from_url(url) == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR"
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_handles_nested_path():
|
||||||
|
url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors"
|
||||||
|
assert repo_id_from_url(url) == "Comfy-Org/ltx-2.3"
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_returns_none_for_non_hf():
|
||||||
|
assert repo_id_from_url("https://civitai.com/x.safetensors") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_repo_id_from_url_returns_none_for_non_resolve_paths():
|
||||||
|
assert repo_id_from_url("https://huggingface.co/org/repo/blob/main/x.safetensors") is None
|
||||||
|
assert repo_id_from_url("https://huggingface.co/org") is None
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Token store
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_roundtrip(patched_user_dir):
|
||||||
|
tok = Token(
|
||||||
|
access_token="hf_abc",
|
||||||
|
refresh_token="rf_def",
|
||||||
|
expires_at=9999999999.0,
|
||||||
|
scope="openid profile",
|
||||||
|
)
|
||||||
|
save_token(tok)
|
||||||
|
loaded = load_token()
|
||||||
|
assert loaded == tok
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_writes_0600(patched_user_dir):
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=0.0)
|
||||||
|
save_token(tok)
|
||||||
|
path = os.path.join(patched_user_dir, "hf_auth_token.json")
|
||||||
|
mode = stat.S_IMODE(os.stat(path).st_mode)
|
||||||
|
# On Windows we silently no-op chmod; allow either the intended
|
||||||
|
# mode or whatever umask the OS gave us.
|
||||||
|
if os.name == "posix":
|
||||||
|
assert mode == 0o600
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_delete_removes_file(patched_user_dir):
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=0.0)
|
||||||
|
save_token(tok)
|
||||||
|
delete_token()
|
||||||
|
path = os.path.join(patched_user_dir, "hf_auth_token.json")
|
||||||
|
assert not os.path.exists(path)
|
||||||
|
# Idempotent: second delete is fine.
|
||||||
|
delete_token()
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_load_returns_none_for_missing_file(patched_user_dir):
|
||||||
|
assert load_token() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir):
|
||||||
|
path = os.path.join(patched_user_dir, "hf_auth_token.json")
|
||||||
|
with open(path, "w") as f:
|
||||||
|
f.write("not json {")
|
||||||
|
assert load_token() is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_token_is_valid_uses_buffer(patched_user_dir):
|
||||||
|
import time
|
||||||
|
|
||||||
|
fresh = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600)
|
||||||
|
nearly_expired = Token(
|
||||||
|
access_token="x",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=time.time() + EXPIRY_BUFFER_SECS - 1,
|
||||||
|
)
|
||||||
|
assert fresh.is_valid()
|
||||||
|
assert not nearly_expired.is_valid()
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Auth store
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_loads_lazily(patched_user_dir):
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
|
||||||
|
save_token(tok)
|
||||||
|
store = HfAuthStore()
|
||||||
|
assert store.has_token()
|
||||||
|
assert store.get_token_sync() == tok
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_set_persists(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
|
||||||
|
store.set_token(tok)
|
||||||
|
# Token is on disk now — a fresh store sees it.
|
||||||
|
assert HfAuthStore().get_token_sync() == tok
|
||||||
|
|
||||||
|
|
||||||
|
def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
|
||||||
|
store.set_token(tok)
|
||||||
|
store.clear()
|
||||||
|
assert not store.has_token()
|
||||||
|
assert HfAuthStore().get_token_sync() is None
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
tok = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600)
|
||||||
|
store.set_token(tok)
|
||||||
|
fetched = await store.get_valid_token()
|
||||||
|
assert fetched == tok
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
refreshed = Token(
|
||||||
|
access_token="new",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() + 3600,
|
||||||
|
)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.hf_auth.oauth.refresh_access_token",
|
||||||
|
new=AsyncMock(return_value=refreshed),
|
||||||
|
):
|
||||||
|
result = await store.get_valid_token()
|
||||||
|
assert result == refreshed
|
||||||
|
|
||||||
|
|
||||||
|
async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir):
|
||||||
|
store = HfAuthStore()
|
||||||
|
import time
|
||||||
|
|
||||||
|
expired = Token(
|
||||||
|
access_token="old",
|
||||||
|
refresh_token="rf",
|
||||||
|
expires_at=time.time() - 100,
|
||||||
|
)
|
||||||
|
store.set_token(expired)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.hf_auth.oauth.refresh_access_token",
|
||||||
|
new=AsyncMock(side_effect=RuntimeError("HF down")),
|
||||||
|
):
|
||||||
|
result = await store.get_valid_token()
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Eligibility
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"listen,multi_user,expected",
|
||||||
|
[
|
||||||
|
("127.0.0.1", False, True),
|
||||||
|
("127.0.0.1", True, False), # multi-user disables it
|
||||||
|
("0.0.0.0", False, False), # bind-all is not loopback
|
||||||
|
("0.0.0.0", True, False),
|
||||||
|
("192.168.1.5", False, False), # LAN address
|
||||||
|
("::1", False, True), # IPv6 loopback
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_eligibility(listen, multi_user, expected, monkeypatch):
|
||||||
|
from app.model_downloader.hf_auth import eligibility
|
||||||
|
from comfy.cli_args import args
|
||||||
|
|
||||||
|
monkeypatch.setattr(args, "listen", listen)
|
||||||
|
monkeypatch.setattr(args, "multi_user", multi_user)
|
||||||
|
assert eligibility.is_hf_auth_eligible() is expected
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# gated_detection HF probe
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def test_probe_url_hf_public(fresh_auth_store):
|
||||||
|
"""auth_check succeeds with no token → is_hf_downloadable = True."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
|
||||||
|
with patch("app.model_downloader.gated_detection._auth_check_sync"), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=1024),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is True
|
||||||
|
assert result.file_size == 1024
|
||||||
|
|
||||||
|
|
||||||
|
async def test_probe_url_hf_gated_no_access(fresh_auth_store):
|
||||||
|
"""auth_check raises GatedRepoError → is_hf_downloadable = False."""
|
||||||
|
from huggingface_hub.errors import GatedRepoError
|
||||||
|
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
|
||||||
|
fake_response = MagicMock(status_code=403)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
side_effect=GatedRepoError("gated", response=fake_response),
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is False
|
||||||
|
|
||||||
|
|
||||||
|
async def test_probe_url_non_hf_skips_auth_check():
|
||||||
|
"""Non-HF URLs never call auth_check; is_hf_downloadable stays None."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://civitai.com/api/download/models/1.safetensors"
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
) as mocked, patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=2048),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is None
|
||||||
|
assert result.file_size == 2048
|
||||||
|
mocked.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_is_gated_cached_across_calls(fresh_auth_store):
|
||||||
|
"""Intrinsic ``is_gated`` should be determined exactly once per URL.
|
||||||
|
|
||||||
|
Subsequent ``probe_url`` calls for the same URL must not re-issue
|
||||||
|
the null-token auth_check — that's the whole point of the cache."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync"
|
||||||
|
) as mocked, patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=1024),
|
||||||
|
):
|
||||||
|
await probe_url(url)
|
||||||
|
await probe_url(url)
|
||||||
|
await probe_url(url)
|
||||||
|
# Three probe_url calls × public-only-needs-1-auth_check = 1 call total.
|
||||||
|
assert mocked.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_file_size_cached_across_calls(fresh_auth_store):
|
||||||
|
"""Once a successful HEAD lands, subsequent calls don't re-HEAD."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
|
||||||
|
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync"
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=2048),
|
||||||
|
) as size_probe:
|
||||||
|
r1 = await probe_url(url)
|
||||||
|
r2 = await probe_url(url)
|
||||||
|
assert r1.file_size == 2048
|
||||||
|
assert r2.file_size == 2048
|
||||||
|
assert size_probe.call_count == 1
|
||||||
|
|
||||||
|
|
||||||
|
async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store):
|
||||||
|
"""When ``is_hf_downloadable`` is False we must NOT HEAD the URL —
|
||||||
|
otherwise a 401-due-to-gating would land as a cached ``None`` that
|
||||||
|
survives a later successful login."""
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
from huggingface_hub.errors import GatedRepoError
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
|
||||||
|
fake_resp = MagicMock(status_code=403)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
side_effect=GatedRepoError("gated", response=fake_resp),
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
) as size_probe:
|
||||||
|
result = await probe_url(url)
|
||||||
|
assert result.is_hf_downloadable is False
|
||||||
|
assert result.file_size is None
|
||||||
|
assert size_probe.call_count == 0
|
||||||
|
|
||||||
|
|
||||||
|
async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir):
|
||||||
|
"""For a gated URL, auth_check runs twice: once with token=None to
|
||||||
|
determine the intrinsic ``is_gated`` flag (cached forever), and once
|
||||||
|
with the stored access_token to determine ``is_hf_downloadable`` for
|
||||||
|
the current user."""
|
||||||
|
from app.model_downloader import gated_detection
|
||||||
|
from app.model_downloader.gated_detection import probe_url
|
||||||
|
from huggingface_hub.errors import GatedRepoError
|
||||||
|
from unittest.mock import MagicMock
|
||||||
|
|
||||||
|
gated_detection.clear_caches_for_tests()
|
||||||
|
fresh_auth_store.set_token(Token(
|
||||||
|
access_token="hf_test_token",
|
||||||
|
refresh_token=None,
|
||||||
|
expires_at=9999999999.0,
|
||||||
|
))
|
||||||
|
url = "https://huggingface.co/private/repo/resolve/main/x.safetensors"
|
||||||
|
|
||||||
|
fake_resp = MagicMock(status_code=403)
|
||||||
|
|
||||||
|
def fake_auth_check(repo_id, token):
|
||||||
|
# Null-token call → repo is gated. Subsequent call with the real
|
||||||
|
# token succeeds (user has access).
|
||||||
|
if token is None:
|
||||||
|
raise GatedRepoError("gated", response=fake_resp)
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.gated_detection._auth_check_sync",
|
||||||
|
side_effect=fake_auth_check,
|
||||||
|
) as mocked, patch(
|
||||||
|
"app.model_downloader.gated_detection._probe_size_once",
|
||||||
|
new=AsyncMock(return_value=None),
|
||||||
|
):
|
||||||
|
result = await probe_url(url)
|
||||||
|
|
||||||
|
# is_hf_downloadable should be True (token-authed call succeeded).
|
||||||
|
assert result.is_hf_downloadable is True
|
||||||
|
# Two calls: (repo_id, None) then (repo_id, <token>).
|
||||||
|
assert mocked.call_count == 2
|
||||||
|
assert mocked.call_args_list[0].args == ("private/repo", None)
|
||||||
|
assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token")
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# OAuth primitives
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_make_pkce_returns_distinct_high_entropy_values():
|
||||||
|
verifier1, challenge1, state1 = oauth._make_pkce()
|
||||||
|
verifier2, challenge2, state2 = oauth._make_pkce()
|
||||||
|
assert verifier1 != verifier2
|
||||||
|
assert challenge1 != challenge2
|
||||||
|
assert state1 != state2
|
||||||
|
# Verifier should be at least 43 chars per PKCE spec.
|
||||||
|
assert len(verifier1) >= 43
|
||||||
|
|
||||||
|
|
||||||
|
def test_build_authorize_url_includes_pkce_and_state():
|
||||||
|
url = oauth._build_authorize_url("challenge123", "state456")
|
||||||
|
assert url.startswith(oauth.AUTHORIZE_URL)
|
||||||
|
assert "client_id=" + oauth.HF_CLIENT_ID in url
|
||||||
|
assert "code_challenge=challenge123" in url
|
||||||
|
assert "code_challenge_method=S256" in url
|
||||||
|
assert "state=state456" in url
|
||||||
|
assert "response_type=code" in url
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Routes
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hf_auth_token_status_empty(aiohttp_client, app):
|
||||||
|
"""No token set → token_available=false, username=null."""
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/api/hf-auth-token-status")
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert data == {"token_available": False, "username": None}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hf_auth_token_status_with_token(
|
||||||
|
aiohttp_client, app, fresh_auth_store, patched_user_dir
|
||||||
|
):
|
||||||
|
"""Token present, whoami works → username is returned."""
|
||||||
|
fresh_auth_store.set_token(Token(
|
||||||
|
access_token="x", refresh_token=None, expires_at=9999999999.0,
|
||||||
|
))
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes._whoami_username",
|
||||||
|
return_value="alice",
|
||||||
|
):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.get("/api/hf-auth-token-status")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json()) == {"token_available": True, "username": "alice"}
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch):
|
||||||
|
"""Not loopback / multi-user → 403."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: False,
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 403
|
||||||
|
assert (await resp.json())["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch):
|
||||||
|
"""Eligible + first attempt → 200 with authorize_url."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.start_login_flow",
|
||||||
|
AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"),
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json())["authorize_url"].startswith(
|
||||||
|
"https://huggingface.co/oauth/authorize"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch):
|
||||||
|
"""Lock already held → 409."""
|
||||||
|
from app.model_downloader.hf_auth.oauth import OAuthInProgressError
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.start_login_flow",
|
||||||
|
AsyncMock(side_effect=OAuthInProgressError()),
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-login-start")
|
||||||
|
assert resp.status == 409
|
||||||
|
assert (await resp.json())["error"]["code"] == "HF_AUTH_IN_PROGRESS"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_hf_auth_logout_clears_store(
|
||||||
|
aiohttp_client, app, fresh_auth_store, patched_user_dir
|
||||||
|
):
|
||||||
|
fresh_auth_store.set_token(Token(
|
||||||
|
access_token="x", refresh_token=None, expires_at=9999999999.0,
|
||||||
|
))
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post("/api/hf-auth-logout")
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json()) == {"logged_out": True}
|
||||||
|
assert not fresh_auth_store.has_token()
|
||||||
|
|
||||||
|
|
||||||
|
async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch):
|
||||||
|
"""The availability response embeds {token_available, eligible}."""
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"app.model_downloader.api.routes.is_hf_auth_eligible",
|
||||||
|
lambda: True,
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/models-availability-status",
|
||||||
|
json={"models": {}},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert "hf_auth" in data
|
||||||
|
assert data["hf_auth"] == {"token_available": False, "eligible": True}
|
||||||
504
tests-unit/app_test/model_downloader_test.py
Normal file
504
tests-unit/app_test/model_downloader_test.py
Normal file
@ -0,0 +1,504 @@
|
|||||||
|
"""Unit tests for the server-side model download subsystem.
|
||||||
|
|
||||||
|
Covers the pieces that don't require talking to a real network:
|
||||||
|
|
||||||
|
- path parsing & allowlist (pure functions)
|
||||||
|
- DownloadServer registry lifecycle (in-memory state)
|
||||||
|
- API routes via aiohttp_client + folder_paths/probe_url patches
|
||||||
|
|
||||||
|
Streaming downloads themselves are exercised indirectly — the route-level
|
||||||
|
tests stub out the network probe so we can verify the gating logic in
|
||||||
|
``download_models`` without making real HTTP calls.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import os
|
||||||
|
from unittest.mock import patch, AsyncMock
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from aiohttp import web
|
||||||
|
|
||||||
|
from app.model_downloader.allowlist import is_url_allowed
|
||||||
|
from app.model_downloader.api.routes import register_routes
|
||||||
|
from app.model_downloader.download_server import DownloadServer
|
||||||
|
from app.model_downloader.gated_detection import MetadataProbeResult
|
||||||
|
from app.model_downloader.paths import (
|
||||||
|
InvalidModelId,
|
||||||
|
parse_model_id,
|
||||||
|
resolve_destination,
|
||||||
|
resolve_existing,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Global asyncio mark: the sync tests below trigger a cosmetic
|
||||||
|
# PytestWarning for each one because pytest-asyncio applies the mark
|
||||||
|
# indiscriminately. Other tests in this repo (see custom_node_manager_test.py)
|
||||||
|
# use the same pattern. The warnings are noise, not failures.
|
||||||
|
pytestmark = pytest.mark.asyncio
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Fixtures
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def model_root(tmp_path):
|
||||||
|
"""A fake ``models/`` root with two registered folder types."""
|
||||||
|
loras_dir = tmp_path / "loras"
|
||||||
|
checkpoints_dir = tmp_path / "checkpoints"
|
||||||
|
loras_dir.mkdir()
|
||||||
|
checkpoints_dir.mkdir()
|
||||||
|
return tmp_path, loras_dir, checkpoints_dir
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def patched_folder_paths(model_root):
|
||||||
|
"""Point folder_paths at our fake roots for the duration of one test."""
|
||||||
|
_root, loras_dir, checkpoints_dir = model_root
|
||||||
|
mapping = {
|
||||||
|
"loras": ([str(loras_dir)], {".safetensors"}),
|
||||||
|
"checkpoints": ([str(checkpoints_dir)], {".safetensors"}),
|
||||||
|
}
|
||||||
|
with patch(
|
||||||
|
"folder_paths.folder_names_and_paths", mapping
|
||||||
|
), patch(
|
||||||
|
"folder_paths.get_folder_paths",
|
||||||
|
side_effect=lambda name: mapping.get(name, ([], set()))[0],
|
||||||
|
):
|
||||||
|
yield mapping
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def fresh_download_server():
|
||||||
|
"""Reset the module-level singleton between tests so registry state
|
||||||
|
doesn't leak across tests sharing the singleton."""
|
||||||
|
from app.model_downloader.download_server import DOWNLOAD_SERVER
|
||||||
|
|
||||||
|
DOWNLOAD_SERVER.reset_for_tests()
|
||||||
|
yield DOWNLOAD_SERVER
|
||||||
|
DOWNLOAD_SERVER.reset_for_tests()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def app(patched_folder_paths, fresh_download_server):
|
||||||
|
app = web.Application()
|
||||||
|
register_routes(app)
|
||||||
|
return app
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Pure helpers: allowlist + path parsing
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_accepts_hf_safetensors():
|
||||||
|
assert is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_accepts_civitai_pth():
|
||||||
|
assert is_url_allowed("https://civitai.com/api/download/models/123.pth")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_rejects_unknown_host():
|
||||||
|
assert not is_url_allowed("https://example.com/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_rejects_api_path_on_hf():
|
||||||
|
# On an allowlisted host but not pointing at a model file.
|
||||||
|
assert not is_url_allowed("https://huggingface.co/api/models")
|
||||||
|
|
||||||
|
|
||||||
|
def test_allowlist_rejects_non_https_except_localhost():
|
||||||
|
assert not is_url_allowed("http://huggingface.co/x/y.safetensors")
|
||||||
|
assert is_url_allowed("http://localhost:8000/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_valid(patched_folder_paths):
|
||||||
|
assert parse_model_id("loras/foo.safetensors") == ("loras", "foo.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_rejects_traversal(patched_folder_paths):
|
||||||
|
with pytest.raises(InvalidModelId):
|
||||||
|
parse_model_id("../etc/passwd")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
|
||||||
|
with pytest.raises(InvalidModelId):
|
||||||
|
parse_model_id("nope/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
|
||||||
|
with pytest.raises(InvalidModelId):
|
||||||
|
parse_model_id("loras/sub/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
target = loras_dir / "foo.safetensors"
|
||||||
|
target.write_bytes(b"x")
|
||||||
|
assert resolve_existing("loras/foo.safetensors") == str(target)
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
|
||||||
|
assert resolve_existing("loras/missing.safetensors") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
final, tmp = resolve_destination("loras/foo.safetensors")
|
||||||
|
assert final == str(loras_dir / "foo.safetensors")
|
||||||
|
assert tmp == final + ".tmp"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
def test_register_is_exclusive():
|
||||||
|
server = DownloadServer()
|
||||||
|
s1 = server.try_register("loras/x.safetensors", "https://huggingface.co/a")
|
||||||
|
s2 = server.try_register("loras/x.safetensors", "https://huggingface.co/b")
|
||||||
|
assert s1 is not None
|
||||||
|
assert s2 is None
|
||||||
|
assert server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_removes_session():
|
||||||
|
server = DownloadServer()
|
||||||
|
server.try_register("loras/x.safetensors", "https://huggingface.co/a")
|
||||||
|
assert server.cancel("loras/x.safetensors") is True
|
||||||
|
assert not server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_cancel_returns_false_when_absent():
|
||||||
|
server = DownloadServer()
|
||||||
|
assert server.cancel("loras/never.safetensors") is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_finish_only_clears_matching_epoch():
|
||||||
|
"""If a session is cancelled and a new one for the same id is
|
||||||
|
registered, ``finish`` from the original worker must not evict the
|
||||||
|
newer session."""
|
||||||
|
server = DownloadServer()
|
||||||
|
s_old = server.try_register("loras/x.safetensors", "u1")
|
||||||
|
server.cancel("loras/x.safetensors")
|
||||||
|
s_new = server.try_register("loras/x.safetensors", "u2")
|
||||||
|
assert s_new is not None and s_new.epoch != s_old.epoch
|
||||||
|
# Old worker's late finish() is a no-op:
|
||||||
|
server.finish(s_old)
|
||||||
|
assert server.is_downloading("loras/x.safetensors")
|
||||||
|
server.finish(s_new)
|
||||||
|
assert not server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_active_follows_cancellation():
|
||||||
|
server = DownloadServer()
|
||||||
|
s = server.try_register("loras/x.safetensors", "u")
|
||||||
|
assert server.is_active(s)
|
||||||
|
server.cancel("loras/x.safetensors")
|
||||||
|
assert not server.is_active(s)
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_progress_tracks_fraction():
|
||||||
|
server = DownloadServer()
|
||||||
|
s = server.try_register("loras/x.safetensors", "u")
|
||||||
|
server.update_progress(s, 50, 100)
|
||||||
|
snap = server.snapshot()["loras/x.safetensors"]
|
||||||
|
assert snap.bytes_downloaded == 50
|
||||||
|
assert snap.total_bytes == 100
|
||||||
|
assert snap.progress == 0.5
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_progress_with_unknown_total_keeps_progress_none():
|
||||||
|
server = DownloadServer()
|
||||||
|
s = server.try_register("loras/x.safetensors", "u")
|
||||||
|
server.update_progress(s, 50, None)
|
||||||
|
assert server.snapshot()["loras/x.safetensors"].progress is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_cleanup_orphan_tmp_files(model_root):
|
||||||
|
"""Orphan .tmp left by a crashed download must be swept on first use."""
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
orphan = loras_dir / "stale.safetensors.tmp"
|
||||||
|
orphan.write_bytes(b"partial")
|
||||||
|
mapping = {"loras": ([str(loras_dir)], {".safetensors"})}
|
||||||
|
with patch("folder_paths.folder_names_and_paths", mapping), patch(
|
||||||
|
"folder_paths.get_folder_paths",
|
||||||
|
side_effect=lambda name: mapping.get(name, ([], set()))[0],
|
||||||
|
):
|
||||||
|
server = DownloadServer()
|
||||||
|
assert orphan.exists(), "sweep must not run at construction time"
|
||||||
|
server.sweep_orphan_tmp_files()
|
||||||
|
assert not orphan.exists()
|
||||||
|
# Idempotent — a second call is a cheap no-op.
|
||||||
|
server.sweep_orphan_tmp_files()
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Route: POST /api/models-availability-status
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def test_availability_partitions_correctly(
|
||||||
|
aiohttp_client, app, model_root, fresh_download_server
|
||||||
|
):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
(loras_dir / "present.safetensors").write_bytes(b"x")
|
||||||
|
fresh_download_server.try_register(
|
||||||
|
"loras/inflight.safetensors", "http://localhost:8000/x.safetensors"
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
|
||||||
|
# Stub probes — we're testing state assignment, not network calls.
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(
|
||||||
|
file_size=None, is_hf_downloadable=None,
|
||||||
|
)),
|
||||||
|
):
|
||||||
|
body = {
|
||||||
|
"models": {
|
||||||
|
"loras/present.safetensors": "http://localhost:8000/p.safetensors",
|
||||||
|
"loras/missing.safetensors": "http://localhost:8000/m.safetensors",
|
||||||
|
"loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
|
||||||
|
}
|
||||||
|
}
|
||||||
|
resp = await client.post("/api/models-availability-status", json=body)
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
models = data["models"]
|
||||||
|
assert models["loras/present.safetensors"]["state"] == "available"
|
||||||
|
assert models["loras/missing.safetensors"]["state"] == "missing"
|
||||||
|
assert models["loras/inflight.safetensors"]["state"] == "downloading"
|
||||||
|
assert "hf_auth" in data
|
||||||
|
|
||||||
|
|
||||||
|
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(
|
||||||
|
file_size=None, is_hf_downloadable=None,
|
||||||
|
)),
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/models-availability-status",
|
||||||
|
json={"models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
data = await resp.json()
|
||||||
|
assert data["models"]["../etc/passwd"]["state"] == "missing"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Route: POST /api/download-models — precondition gating
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}},
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
err = (await resp.json())["error"]
|
||||||
|
assert err["code"] == "URL_NOT_ALLOWED"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_rejects_already_available(
|
||||||
|
aiohttp_client, app, model_root
|
||||||
|
):
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
(loras_dir / "x.safetensors").write_bytes(b"x")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 409
|
||||||
|
assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_rejects_already_downloading(
|
||||||
|
aiohttp_client, app, fresh_download_server
|
||||||
|
):
|
||||||
|
fresh_download_server.try_register(
|
||||||
|
"loras/x.safetensors", "https://huggingface.co/u.safetensors"
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 409
|
||||||
|
assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_rejects_gated_model(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=False)),
|
||||||
|
):
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert (await resp.json())["error"]["code"] == "MODEL_NOT_DOWNLOADABLE"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}},
|
||||||
|
)
|
||||||
|
assert resp.status == 400
|
||||||
|
assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID"
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_atomic_failure_does_not_register_partial(
|
||||||
|
aiohttp_client, app, model_root, fresh_download_server
|
||||||
|
):
|
||||||
|
"""If one model in a batch fails, none get registered."""
|
||||||
|
_root, loras_dir, _ = model_root
|
||||||
|
(loras_dir / "already.safetensors").write_bytes(b"x")
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={
|
||||||
|
"models": {
|
||||||
|
"loras/already.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/already.safetensors",
|
||||||
|
"loras/new.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/new.safetensors",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == 409
|
||||||
|
# The "new" model should not have been registered as part of the
|
||||||
|
# failed batch.
|
||||||
|
assert not fresh_download_server.is_downloading("loras/new.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_download_schedules_when_all_preconditions_pass(
|
||||||
|
aiohttp_client, app, fresh_download_server
|
||||||
|
):
|
||||||
|
"""Verify the precondition pass, registration pass, and async
|
||||||
|
scheduling all wire up correctly. We patch the streamer to avoid
|
||||||
|
real HTTP while still letting the route execute end-to-end."""
|
||||||
|
started = asyncio.Event()
|
||||||
|
finish_signal = asyncio.Event()
|
||||||
|
|
||||||
|
async def fake_stream(session):
|
||||||
|
started.set()
|
||||||
|
await finish_signal.wait()
|
||||||
|
from app.model_downloader.download_server import DOWNLOAD_SERVER
|
||||||
|
DOWNLOAD_SERVER.finish(session)
|
||||||
|
return "/dev/null"
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url",
|
||||||
|
new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)),
|
||||||
|
), patch(
|
||||||
|
"app.model_downloader.downloader.stream_to_disk", new=fake_stream
|
||||||
|
):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/download-models",
|
||||||
|
json={"models": {
|
||||||
|
"loras/new.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/new.safetensors"
|
||||||
|
}},
|
||||||
|
)
|
||||||
|
assert resp.status == 202
|
||||||
|
body = await resp.json()
|
||||||
|
assert body["accepted"] is True
|
||||||
|
assert body["scheduled"] == ["loras/new.safetensors"]
|
||||||
|
# Wait for the worker to actually start.
|
||||||
|
await asyncio.wait_for(started.wait(), timeout=2.0)
|
||||||
|
assert fresh_download_server.is_downloading("loras/new.safetensors")
|
||||||
|
finish_signal.set()
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Route: POST /api/cancel-model-download-session
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cancel_removes_active_session(
|
||||||
|
aiohttp_client, app, fresh_download_server
|
||||||
|
):
|
||||||
|
fresh_download_server.try_register(
|
||||||
|
"loras/x.safetensors", "https://huggingface.co/u.safetensors"
|
||||||
|
)
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/cancel-model-download-session",
|
||||||
|
json={"model_id": "loras/x.safetensors"},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
assert (await resp.json())["cancelled"] is True
|
||||||
|
assert not fresh_download_server.is_downloading("loras/x.safetensors")
|
||||||
|
|
||||||
|
|
||||||
|
async def test_cancel_returns_404_when_none(aiohttp_client, app):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/cancel-model-download-session",
|
||||||
|
json={"model_id": "loras/nothing.safetensors"},
|
||||||
|
)
|
||||||
|
assert resp.status == 404
|
||||||
|
assert (await resp.json())["error"]["code"] == "NOT_DOWNLOADING"
|
||||||
|
|
||||||
|
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
# Unified availability response embeds metadata per id
|
||||||
|
# --------------------------------------------------------------------------- #
|
||||||
|
|
||||||
|
|
||||||
|
async def test_availability_embeds_metadata(aiohttp_client, app):
|
||||||
|
"""``file_size`` + ``is_hf_downloadable`` come back on the same
|
||||||
|
request as the state — no separate metadata endpoint."""
|
||||||
|
results = {
|
||||||
|
"https://huggingface.co/a/b/resolve/main/free.safetensors":
|
||||||
|
MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
|
||||||
|
"https://huggingface.co/g/r/resolve/main/gated.safetensors":
|
||||||
|
MetadataProbeResult(file_size=None, is_hf_downloadable=False),
|
||||||
|
}
|
||||||
|
|
||||||
|
async def fake_probe(url):
|
||||||
|
return results[url]
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.model_downloader.api.routes.probe_url", new=fake_probe
|
||||||
|
):
|
||||||
|
client = await aiohttp_client(app)
|
||||||
|
resp = await client.post(
|
||||||
|
"/api/models-availability-status",
|
||||||
|
json={
|
||||||
|
"models": {
|
||||||
|
"loras/free.safetensors":
|
||||||
|
"https://huggingface.co/a/b/resolve/main/free.safetensors",
|
||||||
|
"loras/gated.safetensors":
|
||||||
|
"https://huggingface.co/g/r/resolve/main/gated.safetensors",
|
||||||
|
}
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert resp.status == 200
|
||||||
|
models = (await resp.json())["models"]
|
||||||
|
assert models["loras/free.safetensors"]["file_size"] == 1024
|
||||||
|
assert models["loras/free.safetensors"]["is_hf_downloadable"] is True
|
||||||
|
assert models["loras/gated.safetensors"]["file_size"] is None
|
||||||
|
assert models["loras/gated.safetensors"]["is_hf_downloadable"] is False
|
||||||
Loading…
Reference in New Issue
Block a user