test(model_downloader): unit tests for downloads, allowlist, paths, and HF auth

Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
This commit is contained in:
DoronGenzelHass 2026-06-22 12:02:46 +03:00
parent 82ccce7d09
commit dcd61f1132
2 changed files with 1071 additions and 0 deletions

View File

@ -0,0 +1,567 @@
"""Unit tests for the HuggingFace auth subsystem.
Covers:
- token store: save/load roundtrip, chmod 0600, atomic write, delete
- eligibility under various CLI-arg combinations
- URL parsing (huggingface.co host detection + repo_id extraction)
- HF-aware gated_detection.probe_url (mocked auth_check)
- HF auth routes (token status, login start with eligibility gate, logout)
- PKCE primitives + authorize URL shape
The OAuth callback server itself isn't exercised end-to-end here — that
requires a real HF server. We test the components (state checking,
URL building, code-exchange request shape) instead.
"""
from __future__ import annotations
import json
import os
import stat
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from aiohttp import web
from app.model_downloader.api.routes import register_routes
from app.model_downloader.hf_auth import oauth
from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore
from app.model_downloader.hf_auth.token_store import (
EXPIRY_BUFFER_SECS,
Token,
delete_token,
load_token,
save_token,
)
from app.model_downloader.hf_url import is_hf_url, repo_id_from_url
pytestmark = pytest.mark.asyncio
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def patched_user_dir(tmp_path):
"""Redirect ``folder_paths.get_user_directory`` so the token file
lands in an isolated tmp_path instead of the real user dir."""
user_dir = tmp_path / "user"
user_dir.mkdir()
with patch("folder_paths.get_user_directory", return_value=str(user_dir)):
yield user_dir
@pytest.fixture
def fresh_auth_store():
"""Wipe singleton state between tests: auth + probe caches."""
from app.model_downloader import gated_detection
HF_AUTH_STORE._token = None
HF_AUTH_STORE._loaded_from_disk = False
gated_detection.clear_caches_for_tests()
yield HF_AUTH_STORE
HF_AUTH_STORE._token = None
HF_AUTH_STORE._loaded_from_disk = False
gated_detection.clear_caches_for_tests()
@pytest.fixture
def app(patched_user_dir, fresh_auth_store):
app = web.Application()
register_routes(app)
return app
# --------------------------------------------------------------------------- #
# URL parsing
# --------------------------------------------------------------------------- #
def test_is_hf_url_recognises_huggingface_co():
assert is_hf_url("https://huggingface.co/x/y/resolve/main/z.safetensors")
assert is_hf_url("https://huggingface.co/abc")
assert not is_hf_url("https://hf-mirror.com/x/y/resolve/main/z.safetensors")
assert not is_hf_url("https://civitai.com/x.safetensors")
def test_repo_id_from_url_extracts_org_and_repo():
url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors"
assert repo_id_from_url(url) == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR"
def test_repo_id_from_url_handles_nested_path():
url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors"
assert repo_id_from_url(url) == "Comfy-Org/ltx-2.3"
def test_repo_id_from_url_returns_none_for_non_hf():
assert repo_id_from_url("https://civitai.com/x.safetensors") is None
def test_repo_id_from_url_returns_none_for_non_resolve_paths():
assert repo_id_from_url("https://huggingface.co/org/repo/blob/main/x.safetensors") is None
assert repo_id_from_url("https://huggingface.co/org") is None
# --------------------------------------------------------------------------- #
# Token store
# --------------------------------------------------------------------------- #
def test_token_store_roundtrip(patched_user_dir):
tok = Token(
access_token="hf_abc",
refresh_token="rf_def",
expires_at=9999999999.0,
scope="openid profile",
)
save_token(tok)
loaded = load_token()
assert loaded == tok
def test_token_store_writes_0600(patched_user_dir):
tok = Token(access_token="x", refresh_token=None, expires_at=0.0)
save_token(tok)
path = os.path.join(patched_user_dir, "hf_auth_token.json")
mode = stat.S_IMODE(os.stat(path).st_mode)
# On Windows we silently no-op chmod; allow either the intended
# mode or whatever umask the OS gave us.
if os.name == "posix":
assert mode == 0o600
def test_token_store_delete_removes_file(patched_user_dir):
tok = Token(access_token="x", refresh_token=None, expires_at=0.0)
save_token(tok)
delete_token()
path = os.path.join(patched_user_dir, "hf_auth_token.json")
assert not os.path.exists(path)
# Idempotent: second delete is fine.
delete_token()
def test_token_store_load_returns_none_for_missing_file(patched_user_dir):
assert load_token() is None
def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir):
path = os.path.join(patched_user_dir, "hf_auth_token.json")
with open(path, "w") as f:
f.write("not json {")
assert load_token() is None
def test_token_is_valid_uses_buffer(patched_user_dir):
import time
fresh = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600)
nearly_expired = Token(
access_token="x",
refresh_token=None,
expires_at=time.time() + EXPIRY_BUFFER_SECS - 1,
)
assert fresh.is_valid()
assert not nearly_expired.is_valid()
# --------------------------------------------------------------------------- #
# Auth store
# --------------------------------------------------------------------------- #
def test_auth_store_loads_lazily(patched_user_dir):
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
save_token(tok)
store = HfAuthStore()
assert store.has_token()
assert store.get_token_sync() == tok
def test_auth_store_set_persists(patched_user_dir):
store = HfAuthStore()
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
store.set_token(tok)
# Token is on disk now — a fresh store sees it.
assert HfAuthStore().get_token_sync() == tok
def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir):
store = HfAuthStore()
tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0)
store.set_token(tok)
store.clear()
assert not store.has_token()
assert HfAuthStore().get_token_sync() is None
async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir):
store = HfAuthStore()
import time
tok = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600)
store.set_token(tok)
fetched = await store.get_valid_token()
assert fetched == tok
async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir):
store = HfAuthStore()
import time
expired = Token(
access_token="old",
refresh_token="rf",
expires_at=time.time() - 100,
)
store.set_token(expired)
refreshed = Token(
access_token="new",
refresh_token="rf",
expires_at=time.time() + 3600,
)
with patch(
"app.model_downloader.hf_auth.oauth.refresh_access_token",
new=AsyncMock(return_value=refreshed),
):
result = await store.get_valid_token()
assert result == refreshed
async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir):
store = HfAuthStore()
import time
expired = Token(
access_token="old",
refresh_token="rf",
expires_at=time.time() - 100,
)
store.set_token(expired)
with patch(
"app.model_downloader.hf_auth.oauth.refresh_access_token",
new=AsyncMock(side_effect=RuntimeError("HF down")),
):
result = await store.get_valid_token()
assert result is None
# --------------------------------------------------------------------------- #
# Eligibility
# --------------------------------------------------------------------------- #
@pytest.mark.parametrize(
"listen,multi_user,expected",
[
("127.0.0.1", False, True),
("127.0.0.1", True, False), # multi-user disables it
("0.0.0.0", False, False), # bind-all is not loopback
("0.0.0.0", True, False),
("192.168.1.5", False, False), # LAN address
("::1", False, True), # IPv6 loopback
],
)
def test_eligibility(listen, multi_user, expected, monkeypatch):
from app.model_downloader.hf_auth import eligibility
from comfy.cli_args import args
monkeypatch.setattr(args, "listen", listen)
monkeypatch.setattr(args, "multi_user", multi_user)
assert eligibility.is_hf_auth_eligible() is expected
# --------------------------------------------------------------------------- #
# gated_detection HF probe
# --------------------------------------------------------------------------- #
async def test_probe_url_hf_public(fresh_auth_store):
"""auth_check succeeds with no token → is_hf_downloadable = True."""
from app.model_downloader.gated_detection import probe_url
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
with patch("app.model_downloader.gated_detection._auth_check_sync"), patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=1024),
):
result = await probe_url(url)
assert result.is_hf_downloadable is True
assert result.file_size == 1024
async def test_probe_url_hf_gated_no_access(fresh_auth_store):
"""auth_check raises GatedRepoError → is_hf_downloadable = False."""
from huggingface_hub.errors import GatedRepoError
from app.model_downloader.gated_detection import probe_url
url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
fake_response = MagicMock(status_code=403)
with patch(
"app.model_downloader.gated_detection._auth_check_sync",
side_effect=GatedRepoError("gated", response=fake_response),
), patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=None),
):
result = await probe_url(url)
assert result.is_hf_downloadable is False
async def test_probe_url_non_hf_skips_auth_check():
"""Non-HF URLs never call auth_check; is_hf_downloadable stays None."""
from app.model_downloader.gated_detection import probe_url
url = "https://civitai.com/api/download/models/1.safetensors"
with patch(
"app.model_downloader.gated_detection._auth_check_sync",
) as mocked, patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=2048),
):
result = await probe_url(url)
assert result.is_hf_downloadable is None
assert result.file_size == 2048
mocked.assert_not_called()
async def test_is_gated_cached_across_calls(fresh_auth_store):
"""Intrinsic ``is_gated`` should be determined exactly once per URL.
Subsequent ``probe_url`` calls for the same URL must not re-issue
the null-token auth_check that's the whole point of the cache."""
from app.model_downloader.gated_detection import probe_url
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
with patch(
"app.model_downloader.gated_detection._auth_check_sync"
) as mocked, patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=1024),
):
await probe_url(url)
await probe_url(url)
await probe_url(url)
# Three probe_url calls × public-only-needs-1-auth_check = 1 call total.
assert mocked.call_count == 1
async def test_file_size_cached_across_calls(fresh_auth_store):
"""Once a successful HEAD lands, subsequent calls don't re-HEAD."""
from app.model_downloader.gated_detection import probe_url
url = "https://huggingface.co/public/repo/resolve/main/x.safetensors"
with patch(
"app.model_downloader.gated_detection._auth_check_sync"
), patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=2048),
) as size_probe:
r1 = await probe_url(url)
r2 = await probe_url(url)
assert r1.file_size == 2048
assert r2.file_size == 2048
assert size_probe.call_count == 1
async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store):
"""When ``is_hf_downloadable`` is False we must NOT HEAD the URL —
otherwise a 401-due-to-gating would land as a cached ``None`` that
survives a later successful login."""
from app.model_downloader.gated_detection import probe_url
from huggingface_hub.errors import GatedRepoError
from unittest.mock import MagicMock
url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors"
fake_resp = MagicMock(status_code=403)
with patch(
"app.model_downloader.gated_detection._auth_check_sync",
side_effect=GatedRepoError("gated", response=fake_resp),
), patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=None),
) as size_probe:
result = await probe_url(url)
assert result.is_hf_downloadable is False
assert result.file_size is None
assert size_probe.call_count == 0
async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir):
"""For a gated URL, auth_check runs twice: once with token=None to
determine the intrinsic ``is_gated`` flag (cached forever), and once
with the stored access_token to determine ``is_hf_downloadable`` for
the current user."""
from app.model_downloader import gated_detection
from app.model_downloader.gated_detection import probe_url
from huggingface_hub.errors import GatedRepoError
from unittest.mock import MagicMock
gated_detection.clear_caches_for_tests()
fresh_auth_store.set_token(Token(
access_token="hf_test_token",
refresh_token=None,
expires_at=9999999999.0,
))
url = "https://huggingface.co/private/repo/resolve/main/x.safetensors"
fake_resp = MagicMock(status_code=403)
def fake_auth_check(repo_id, token):
# Null-token call → repo is gated. Subsequent call with the real
# token succeeds (user has access).
if token is None:
raise GatedRepoError("gated", response=fake_resp)
with patch(
"app.model_downloader.gated_detection._auth_check_sync",
side_effect=fake_auth_check,
) as mocked, patch(
"app.model_downloader.gated_detection._probe_size_once",
new=AsyncMock(return_value=None),
):
result = await probe_url(url)
# is_hf_downloadable should be True (token-authed call succeeded).
assert result.is_hf_downloadable is True
# Two calls: (repo_id, None) then (repo_id, <token>).
assert mocked.call_count == 2
assert mocked.call_args_list[0].args == ("private/repo", None)
assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token")
# --------------------------------------------------------------------------- #
# OAuth primitives
# --------------------------------------------------------------------------- #
def test_make_pkce_returns_distinct_high_entropy_values():
verifier1, challenge1, state1 = oauth._make_pkce()
verifier2, challenge2, state2 = oauth._make_pkce()
assert verifier1 != verifier2
assert challenge1 != challenge2
assert state1 != state2
# Verifier should be at least 43 chars per PKCE spec.
assert len(verifier1) >= 43
def test_build_authorize_url_includes_pkce_and_state():
url = oauth._build_authorize_url("challenge123", "state456")
assert url.startswith(oauth.AUTHORIZE_URL)
assert "client_id=" + oauth.HF_CLIENT_ID in url
assert "code_challenge=challenge123" in url
assert "code_challenge_method=S256" in url
assert "state=state456" in url
assert "response_type=code" in url
# --------------------------------------------------------------------------- #
# Routes
# --------------------------------------------------------------------------- #
async def test_hf_auth_token_status_empty(aiohttp_client, app):
"""No token set → token_available=false, username=null."""
client = await aiohttp_client(app)
resp = await client.get("/api/hf-auth-token-status")
assert resp.status == 200
data = await resp.json()
assert data == {"token_available": False, "username": None}
async def test_hf_auth_token_status_with_token(
aiohttp_client, app, fresh_auth_store, patched_user_dir
):
"""Token present, whoami works → username is returned."""
fresh_auth_store.set_token(Token(
access_token="x", refresh_token=None, expires_at=9999999999.0,
))
with patch(
"app.model_downloader.api.routes._whoami_username",
return_value="alice",
):
client = await aiohttp_client(app)
resp = await client.get("/api/hf-auth-token-status")
assert resp.status == 200
assert (await resp.json()) == {"token_available": True, "username": "alice"}
async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch):
"""Not loopback / multi-user → 403."""
monkeypatch.setattr(
"app.model_downloader.api.routes.is_hf_auth_eligible",
lambda: False,
)
client = await aiohttp_client(app)
resp = await client.post("/api/hf-auth-login-start")
assert resp.status == 403
assert (await resp.json())["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE"
async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch):
"""Eligible + first attempt → 200 with authorize_url."""
monkeypatch.setattr(
"app.model_downloader.api.routes.is_hf_auth_eligible",
lambda: True,
)
monkeypatch.setattr(
"app.model_downloader.api.routes.start_login_flow",
AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"),
)
client = await aiohttp_client(app)
resp = await client.post("/api/hf-auth-login-start")
assert resp.status == 200
assert (await resp.json())["authorize_url"].startswith(
"https://huggingface.co/oauth/authorize"
)
async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch):
"""Lock already held → 409."""
from app.model_downloader.hf_auth.oauth import OAuthInProgressError
monkeypatch.setattr(
"app.model_downloader.api.routes.is_hf_auth_eligible",
lambda: True,
)
monkeypatch.setattr(
"app.model_downloader.api.routes.start_login_flow",
AsyncMock(side_effect=OAuthInProgressError()),
)
client = await aiohttp_client(app)
resp = await client.post("/api/hf-auth-login-start")
assert resp.status == 409
assert (await resp.json())["error"]["code"] == "HF_AUTH_IN_PROGRESS"
async def test_hf_auth_logout_clears_store(
aiohttp_client, app, fresh_auth_store, patched_user_dir
):
fresh_auth_store.set_token(Token(
access_token="x", refresh_token=None, expires_at=9999999999.0,
))
client = await aiohttp_client(app)
resp = await client.post("/api/hf-auth-logout")
assert resp.status == 200
assert (await resp.json()) == {"logged_out": True}
assert not fresh_auth_store.has_token()
async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch):
"""The availability response embeds {token_available, eligible}."""
monkeypatch.setattr(
"app.model_downloader.api.routes.is_hf_auth_eligible",
lambda: True,
)
client = await aiohttp_client(app)
resp = await client.post(
"/api/models-availability-status",
json={"models": {}},
)
assert resp.status == 200
data = await resp.json()
assert "hf_auth" in data
assert data["hf_auth"] == {"token_available": False, "eligible": True}

View File

@ -0,0 +1,504 @@
"""Unit tests for the server-side model download subsystem.
Covers the pieces that don't require talking to a real network:
- path parsing & allowlist (pure functions)
- DownloadServer registry lifecycle (in-memory state)
- API routes via aiohttp_client + folder_paths/probe_url patches
Streaming downloads themselves are exercised indirectly the route-level
tests stub out the network probe so we can verify the gating logic in
``download_models`` without making real HTTP calls.
"""
from __future__ import annotations
import asyncio
import os
from unittest.mock import patch, AsyncMock
import pytest
from aiohttp import web
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.api.routes import register_routes
from app.model_downloader.download_server import DownloadServer
from app.model_downloader.gated_detection import MetadataProbeResult
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_destination,
resolve_existing,
)
# Global asyncio mark: the sync tests below trigger a cosmetic
# PytestWarning for each one because pytest-asyncio applies the mark
# indiscriminately. Other tests in this repo (see custom_node_manager_test.py)
# use the same pattern. The warnings are noise, not failures.
pytestmark = pytest.mark.asyncio
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def model_root(tmp_path):
"""A fake ``models/`` root with two registered folder types."""
loras_dir = tmp_path / "loras"
checkpoints_dir = tmp_path / "checkpoints"
loras_dir.mkdir()
checkpoints_dir.mkdir()
return tmp_path, loras_dir, checkpoints_dir
@pytest.fixture
def patched_folder_paths(model_root):
"""Point folder_paths at our fake roots for the duration of one test."""
_root, loras_dir, checkpoints_dir = model_root
mapping = {
"loras": ([str(loras_dir)], {".safetensors"}),
"checkpoints": ([str(checkpoints_dir)], {".safetensors"}),
}
with patch(
"folder_paths.folder_names_and_paths", mapping
), patch(
"folder_paths.get_folder_paths",
side_effect=lambda name: mapping.get(name, ([], set()))[0],
):
yield mapping
@pytest.fixture
def fresh_download_server():
"""Reset the module-level singleton between tests so registry state
doesn't leak across tests sharing the singleton."""
from app.model_downloader.download_server import DOWNLOAD_SERVER
DOWNLOAD_SERVER.reset_for_tests()
yield DOWNLOAD_SERVER
DOWNLOAD_SERVER.reset_for_tests()
@pytest.fixture
def app(patched_folder_paths, fresh_download_server):
app = web.Application()
register_routes(app)
return app
# --------------------------------------------------------------------------- #
# Pure helpers: allowlist + path parsing
# --------------------------------------------------------------------------- #
def test_allowlist_accepts_hf_safetensors():
assert is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors")
def test_allowlist_accepts_civitai_pth():
assert is_url_allowed("https://civitai.com/api/download/models/123.pth")
def test_allowlist_rejects_unknown_host():
assert not is_url_allowed("https://example.com/x.safetensors")
def test_allowlist_rejects_api_path_on_hf():
# On an allowlisted host but not pointing at a model file.
assert not is_url_allowed("https://huggingface.co/api/models")
def test_allowlist_rejects_non_https_except_localhost():
assert not is_url_allowed("http://huggingface.co/x/y.safetensors")
assert is_url_allowed("http://localhost:8000/x.safetensors")
def test_parse_model_id_valid(patched_folder_paths):
assert parse_model_id("loras/foo.safetensors") == ("loras", "foo.safetensors")
def test_parse_model_id_rejects_traversal(patched_folder_paths):
with pytest.raises(InvalidModelId):
parse_model_id("../etc/passwd")
def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
with pytest.raises(InvalidModelId):
parse_model_id("nope/x.safetensors")
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
with pytest.raises(InvalidModelId):
parse_model_id("loras/sub/x.safetensors")
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
_root, loras_dir, _ = model_root
target = loras_dir / "foo.safetensors"
target.write_bytes(b"x")
assert resolve_existing("loras/foo.safetensors") == str(target)
def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
assert resolve_existing("loras/missing.safetensors") is None
def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
_root, loras_dir, _ = model_root
final, tmp = resolve_destination("loras/foo.safetensors")
assert final == str(loras_dir / "foo.safetensors")
assert tmp == final + ".tmp"
# --------------------------------------------------------------------------- #
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
# --------------------------------------------------------------------------- #
def test_register_is_exclusive():
server = DownloadServer()
s1 = server.try_register("loras/x.safetensors", "https://huggingface.co/a")
s2 = server.try_register("loras/x.safetensors", "https://huggingface.co/b")
assert s1 is not None
assert s2 is None
assert server.is_downloading("loras/x.safetensors")
def test_cancel_removes_session():
server = DownloadServer()
server.try_register("loras/x.safetensors", "https://huggingface.co/a")
assert server.cancel("loras/x.safetensors") is True
assert not server.is_downloading("loras/x.safetensors")
def test_cancel_returns_false_when_absent():
server = DownloadServer()
assert server.cancel("loras/never.safetensors") is False
def test_finish_only_clears_matching_epoch():
"""If a session is cancelled and a new one for the same id is
registered, ``finish`` from the original worker must not evict the
newer session."""
server = DownloadServer()
s_old = server.try_register("loras/x.safetensors", "u1")
server.cancel("loras/x.safetensors")
s_new = server.try_register("loras/x.safetensors", "u2")
assert s_new is not None and s_new.epoch != s_old.epoch
# Old worker's late finish() is a no-op:
server.finish(s_old)
assert server.is_downloading("loras/x.safetensors")
server.finish(s_new)
assert not server.is_downloading("loras/x.safetensors")
def test_is_active_follows_cancellation():
server = DownloadServer()
s = server.try_register("loras/x.safetensors", "u")
assert server.is_active(s)
server.cancel("loras/x.safetensors")
assert not server.is_active(s)
def test_update_progress_tracks_fraction():
server = DownloadServer()
s = server.try_register("loras/x.safetensors", "u")
server.update_progress(s, 50, 100)
snap = server.snapshot()["loras/x.safetensors"]
assert snap.bytes_downloaded == 50
assert snap.total_bytes == 100
assert snap.progress == 0.5
def test_update_progress_with_unknown_total_keeps_progress_none():
server = DownloadServer()
s = server.try_register("loras/x.safetensors", "u")
server.update_progress(s, 50, None)
assert server.snapshot()["loras/x.safetensors"].progress is None
def test_cleanup_orphan_tmp_files(model_root):
"""Orphan .tmp left by a crashed download must be swept on first use."""
_root, loras_dir, _ = model_root
orphan = loras_dir / "stale.safetensors.tmp"
orphan.write_bytes(b"partial")
mapping = {"loras": ([str(loras_dir)], {".safetensors"})}
with patch("folder_paths.folder_names_and_paths", mapping), patch(
"folder_paths.get_folder_paths",
side_effect=lambda name: mapping.get(name, ([], set()))[0],
):
server = DownloadServer()
assert orphan.exists(), "sweep must not run at construction time"
server.sweep_orphan_tmp_files()
assert not orphan.exists()
# Idempotent — a second call is a cheap no-op.
server.sweep_orphan_tmp_files()
# --------------------------------------------------------------------------- #
# Route: POST /api/models-availability-status
# --------------------------------------------------------------------------- #
async def test_availability_partitions_correctly(
aiohttp_client, app, model_root, fresh_download_server
):
_root, loras_dir, _ = model_root
(loras_dir / "present.safetensors").write_bytes(b"x")
fresh_download_server.try_register(
"loras/inflight.safetensors", "http://localhost:8000/x.safetensors"
)
client = await aiohttp_client(app)
# Stub probes — we're testing state assignment, not network calls.
with patch(
"app.model_downloader.api.routes.probe_url",
new=AsyncMock(return_value=MetadataProbeResult(
file_size=None, is_hf_downloadable=None,
)),
):
body = {
"models": {
"loras/present.safetensors": "http://localhost:8000/p.safetensors",
"loras/missing.safetensors": "http://localhost:8000/m.safetensors",
"loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
}
}
resp = await client.post("/api/models-availability-status", json=body)
assert resp.status == 200
data = await resp.json()
models = data["models"]
assert models["loras/present.safetensors"]["state"] == "available"
assert models["loras/missing.safetensors"]["state"] == "missing"
assert models["loras/inflight.safetensors"]["state"] == "downloading"
assert "hf_auth" in data
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
client = await aiohttp_client(app)
with patch(
"app.model_downloader.api.routes.probe_url",
new=AsyncMock(return_value=MetadataProbeResult(
file_size=None, is_hf_downloadable=None,
)),
):
resp = await client.post(
"/api/models-availability-status",
json={"models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}},
)
assert resp.status == 200
data = await resp.json()
assert data["models"]["../etc/passwd"]["state"] == "missing"
# --------------------------------------------------------------------------- #
# Route: POST /api/download-models — precondition gating
# --------------------------------------------------------------------------- #
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.post(
"/api/download-models",
json={"models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}},
)
assert resp.status == 400
err = (await resp.json())["error"]
assert err["code"] == "URL_NOT_ALLOWED"
async def test_download_rejects_already_available(
aiohttp_client, app, model_root
):
_root, loras_dir, _ = model_root
(loras_dir / "x.safetensors").write_bytes(b"x")
client = await aiohttp_client(app)
resp = await client.post(
"/api/download-models",
json={"models": {
"loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
}},
)
assert resp.status == 409
assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE"
async def test_download_rejects_already_downloading(
aiohttp_client, app, fresh_download_server
):
fresh_download_server.try_register(
"loras/x.safetensors", "https://huggingface.co/u.safetensors"
)
client = await aiohttp_client(app)
resp = await client.post(
"/api/download-models",
json={"models": {
"loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors"
}},
)
assert resp.status == 409
assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING"
async def test_download_rejects_gated_model(aiohttp_client, app):
client = await aiohttp_client(app)
with patch(
"app.model_downloader.api.routes.probe_url",
new=AsyncMock(return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=False)),
):
resp = await client.post(
"/api/download-models",
json={"models": {
"loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors"
}},
)
assert resp.status == 400
assert (await resp.json())["error"]["code"] == "MODEL_NOT_DOWNLOADABLE"
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.post(
"/api/download-models",
json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}},
)
assert resp.status == 400
assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID"
async def test_download_atomic_failure_does_not_register_partial(
aiohttp_client, app, model_root, fresh_download_server
):
"""If one model in a batch fails, none get registered."""
_root, loras_dir, _ = model_root
(loras_dir / "already.safetensors").write_bytes(b"x")
client = await aiohttp_client(app)
resp = await client.post(
"/api/download-models",
json={
"models": {
"loras/already.safetensors":
"https://huggingface.co/a/b/resolve/main/already.safetensors",
"loras/new.safetensors":
"https://huggingface.co/a/b/resolve/main/new.safetensors",
}
},
)
assert resp.status == 409
# The "new" model should not have been registered as part of the
# failed batch.
assert not fresh_download_server.is_downloading("loras/new.safetensors")
async def test_download_schedules_when_all_preconditions_pass(
aiohttp_client, app, fresh_download_server
):
"""Verify the precondition pass, registration pass, and async
scheduling all wire up correctly. We patch the streamer to avoid
real HTTP while still letting the route execute end-to-end."""
started = asyncio.Event()
finish_signal = asyncio.Event()
async def fake_stream(session):
started.set()
await finish_signal.wait()
from app.model_downloader.download_server import DOWNLOAD_SERVER
DOWNLOAD_SERVER.finish(session)
return "/dev/null"
with patch(
"app.model_downloader.api.routes.probe_url",
new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)),
), patch(
"app.model_downloader.downloader.stream_to_disk", new=fake_stream
):
client = await aiohttp_client(app)
resp = await client.post(
"/api/download-models",
json={"models": {
"loras/new.safetensors":
"https://huggingface.co/a/b/resolve/main/new.safetensors"
}},
)
assert resp.status == 202
body = await resp.json()
assert body["accepted"] is True
assert body["scheduled"] == ["loras/new.safetensors"]
# Wait for the worker to actually start.
await asyncio.wait_for(started.wait(), timeout=2.0)
assert fresh_download_server.is_downloading("loras/new.safetensors")
finish_signal.set()
# --------------------------------------------------------------------------- #
# Route: POST /api/cancel-model-download-session
# --------------------------------------------------------------------------- #
async def test_cancel_removes_active_session(
aiohttp_client, app, fresh_download_server
):
fresh_download_server.try_register(
"loras/x.safetensors", "https://huggingface.co/u.safetensors"
)
client = await aiohttp_client(app)
resp = await client.post(
"/api/cancel-model-download-session",
json={"model_id": "loras/x.safetensors"},
)
assert resp.status == 200
assert (await resp.json())["cancelled"] is True
assert not fresh_download_server.is_downloading("loras/x.safetensors")
async def test_cancel_returns_404_when_none(aiohttp_client, app):
client = await aiohttp_client(app)
resp = await client.post(
"/api/cancel-model-download-session",
json={"model_id": "loras/nothing.safetensors"},
)
assert resp.status == 404
assert (await resp.json())["error"]["code"] == "NOT_DOWNLOADING"
# --------------------------------------------------------------------------- #
# Unified availability response embeds metadata per id
# --------------------------------------------------------------------------- #
async def test_availability_embeds_metadata(aiohttp_client, app):
"""``file_size`` + ``is_hf_downloadable`` come back on the same
request as the state no separate metadata endpoint."""
results = {
"https://huggingface.co/a/b/resolve/main/free.safetensors":
MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
"https://huggingface.co/g/r/resolve/main/gated.safetensors":
MetadataProbeResult(file_size=None, is_hf_downloadable=False),
}
async def fake_probe(url):
return results[url]
with patch(
"app.model_downloader.api.routes.probe_url", new=fake_probe
):
client = await aiohttp_client(app)
resp = await client.post(
"/api/models-availability-status",
json={
"models": {
"loras/free.safetensors":
"https://huggingface.co/a/b/resolve/main/free.safetensors",
"loras/gated.safetensors":
"https://huggingface.co/g/r/resolve/main/gated.safetensors",
}
},
)
assert resp.status == 200
models = (await resp.json())["models"]
assert models["loras/free.safetensors"]["file_size"] == 1024
assert models["loras/free.safetensors"]["is_hf_downloadable"] is True
assert models["loras/gated.safetensors"]["file_size"] is None
assert models["loras/gated.safetensors"]["is_hf_downloadable"] is False