diff --git a/tests-unit/app_test/hf_auth_test.py b/tests-unit/app_test/hf_auth_test.py new file mode 100644 index 000000000..39d0d6ee7 --- /dev/null +++ b/tests-unit/app_test/hf_auth_test.py @@ -0,0 +1,567 @@ +"""Unit tests for the HuggingFace auth subsystem. + +Covers: + - token store: save/load roundtrip, chmod 0600, atomic write, delete + - eligibility under various CLI-arg combinations + - URL parsing (huggingface.co host detection + repo_id extraction) + - HF-aware gated_detection.probe_url (mocked auth_check) + - HF auth routes (token status, login start with eligibility gate, logout) + - PKCE primitives + authorize URL shape + +The OAuth callback server itself isn't exercised end-to-end here — that +requires a real HF server. We test the components (state checking, +URL building, code-exchange request shape) instead. +""" + +from __future__ import annotations + +import json +import os +import stat +from unittest.mock import AsyncMock, MagicMock, patch + +import pytest +from aiohttp import web + +from app.model_downloader.api.routes import register_routes +from app.model_downloader.hf_auth import oauth +from app.model_downloader.hf_auth.auth_store import HF_AUTH_STORE, HfAuthStore +from app.model_downloader.hf_auth.token_store import ( + EXPIRY_BUFFER_SECS, + Token, + delete_token, + load_token, + save_token, +) +from app.model_downloader.hf_url import is_hf_url, repo_id_from_url + +pytestmark = pytest.mark.asyncio + + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def patched_user_dir(tmp_path): + """Redirect ``folder_paths.get_user_directory`` so the token file + lands in an isolated tmp_path instead of the real user dir.""" + user_dir = tmp_path / "user" + user_dir.mkdir() + with patch("folder_paths.get_user_directory", return_value=str(user_dir)): + yield user_dir + + +@pytest.fixture +def fresh_auth_store(): + """Wipe singleton state between tests: auth + probe caches.""" + from app.model_downloader import gated_detection + + HF_AUTH_STORE._token = None + HF_AUTH_STORE._loaded_from_disk = False + gated_detection.clear_caches_for_tests() + yield HF_AUTH_STORE + HF_AUTH_STORE._token = None + HF_AUTH_STORE._loaded_from_disk = False + gated_detection.clear_caches_for_tests() + + +@pytest.fixture +def app(patched_user_dir, fresh_auth_store): + app = web.Application() + register_routes(app) + return app + + +# --------------------------------------------------------------------------- # +# URL parsing +# --------------------------------------------------------------------------- # + + +def test_is_hf_url_recognises_huggingface_co(): + assert is_hf_url("https://huggingface.co/x/y/resolve/main/z.safetensors") + assert is_hf_url("https://huggingface.co/abc") + assert not is_hf_url("https://hf-mirror.com/x/y/resolve/main/z.safetensors") + assert not is_hf_url("https://civitai.com/x.safetensors") + + +def test_repo_id_from_url_extracts_org_and_repo(): + url = "https://huggingface.co/Lightricks/LTX-2.3-22b-IC-LoRA-HDR/resolve/main/x.safetensors" + assert repo_id_from_url(url) == "Lightricks/LTX-2.3-22b-IC-LoRA-HDR" + + +def test_repo_id_from_url_handles_nested_path(): + url = "https://huggingface.co/Comfy-Org/ltx-2.3/resolve/main/split_files/loras/x.safetensors" + assert repo_id_from_url(url) == "Comfy-Org/ltx-2.3" + + +def test_repo_id_from_url_returns_none_for_non_hf(): + assert repo_id_from_url("https://civitai.com/x.safetensors") is None + + +def test_repo_id_from_url_returns_none_for_non_resolve_paths(): + assert repo_id_from_url("https://huggingface.co/org/repo/blob/main/x.safetensors") is None + assert repo_id_from_url("https://huggingface.co/org") is None + + +# --------------------------------------------------------------------------- # +# Token store +# --------------------------------------------------------------------------- # + + +def test_token_store_roundtrip(patched_user_dir): + tok = Token( + access_token="hf_abc", + refresh_token="rf_def", + expires_at=9999999999.0, + scope="openid profile", + ) + save_token(tok) + loaded = load_token() + assert loaded == tok + + +def test_token_store_writes_0600(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=0.0) + save_token(tok) + path = os.path.join(patched_user_dir, "hf_auth_token.json") + mode = stat.S_IMODE(os.stat(path).st_mode) + # On Windows we silently no-op chmod; allow either the intended + # mode or whatever umask the OS gave us. + if os.name == "posix": + assert mode == 0o600 + + +def test_token_store_delete_removes_file(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=0.0) + save_token(tok) + delete_token() + path = os.path.join(patched_user_dir, "hf_auth_token.json") + assert not os.path.exists(path) + # Idempotent: second delete is fine. + delete_token() + + +def test_token_store_load_returns_none_for_missing_file(patched_user_dir): + assert load_token() is None + + +def test_token_store_load_returns_none_for_corrupt_file(patched_user_dir): + path = os.path.join(patched_user_dir, "hf_auth_token.json") + with open(path, "w") as f: + f.write("not json {") + assert load_token() is None + + +def test_token_is_valid_uses_buffer(patched_user_dir): + import time + + fresh = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600) + nearly_expired = Token( + access_token="x", + refresh_token=None, + expires_at=time.time() + EXPIRY_BUFFER_SECS - 1, + ) + assert fresh.is_valid() + assert not nearly_expired.is_valid() + + +# --------------------------------------------------------------------------- # +# Auth store +# --------------------------------------------------------------------------- # + + +def test_auth_store_loads_lazily(patched_user_dir): + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + save_token(tok) + store = HfAuthStore() + assert store.has_token() + assert store.get_token_sync() == tok + + +def test_auth_store_set_persists(patched_user_dir): + store = HfAuthStore() + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + store.set_token(tok) + # Token is on disk now — a fresh store sees it. + assert HfAuthStore().get_token_sync() == tok + + +def test_auth_store_clear_removes_in_memory_and_on_disk(patched_user_dir): + store = HfAuthStore() + tok = Token(access_token="x", refresh_token=None, expires_at=9999999999.0) + store.set_token(tok) + store.clear() + assert not store.has_token() + assert HfAuthStore().get_token_sync() is None + + +async def test_auth_store_get_valid_returns_fresh_token(patched_user_dir): + store = HfAuthStore() + import time + + tok = Token(access_token="x", refresh_token=None, expires_at=time.time() + 3600) + store.set_token(tok) + fetched = await store.get_valid_token() + assert fetched == tok + + +async def test_auth_store_get_valid_refresh_on_expired(patched_user_dir): + store = HfAuthStore() + import time + + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + refreshed = Token( + access_token="new", + refresh_token="rf", + expires_at=time.time() + 3600, + ) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(return_value=refreshed), + ): + result = await store.get_valid_token() + assert result == refreshed + + +async def test_auth_store_get_valid_returns_none_on_refresh_failure(patched_user_dir): + store = HfAuthStore() + import time + + expired = Token( + access_token="old", + refresh_token="rf", + expires_at=time.time() - 100, + ) + store.set_token(expired) + with patch( + "app.model_downloader.hf_auth.oauth.refresh_access_token", + new=AsyncMock(side_effect=RuntimeError("HF down")), + ): + result = await store.get_valid_token() + assert result is None + + +# --------------------------------------------------------------------------- # +# Eligibility +# --------------------------------------------------------------------------- # + + +@pytest.mark.parametrize( + "listen,multi_user,expected", + [ + ("127.0.0.1", False, True), + ("127.0.0.1", True, False), # multi-user disables it + ("0.0.0.0", False, False), # bind-all is not loopback + ("0.0.0.0", True, False), + ("192.168.1.5", False, False), # LAN address + ("::1", False, True), # IPv6 loopback + ], +) +def test_eligibility(listen, multi_user, expected, monkeypatch): + from app.model_downloader.hf_auth import eligibility + from comfy.cli_args import args + + monkeypatch.setattr(args, "listen", listen) + monkeypatch.setattr(args, "multi_user", multi_user) + assert eligibility.is_hf_auth_eligible() is expected + + +# --------------------------------------------------------------------------- # +# gated_detection HF probe +# --------------------------------------------------------------------------- # + + +async def test_probe_url_hf_public(fresh_auth_store): + """auth_check succeeds with no token → is_hf_downloadable = True.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch("app.model_downloader.gated_detection._auth_check_sync"), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=1024), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is True + assert result.file_size == 1024 + + +async def test_probe_url_hf_gated_no_access(fresh_auth_store): + """auth_check raises GatedRepoError → is_hf_downloadable = False.""" + from huggingface_hub.errors import GatedRepoError + + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors" + fake_response = MagicMock(status_code=403) + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=GatedRepoError("gated", response=fake_response), + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is False + + +async def test_probe_url_non_hf_skips_auth_check(): + """Non-HF URLs never call auth_check; is_hf_downloadable stays None.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://civitai.com/api/download/models/1.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=2048), + ): + result = await probe_url(url) + assert result.is_hf_downloadable is None + assert result.file_size == 2048 + mocked.assert_not_called() + + +async def test_is_gated_cached_across_calls(fresh_auth_store): + """Intrinsic ``is_gated`` should be determined exactly once per URL. + + Subsequent ``probe_url`` calls for the same URL must not re-issue + the null-token auth_check — that's the whole point of the cache.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync" + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=1024), + ): + await probe_url(url) + await probe_url(url) + await probe_url(url) + # Three probe_url calls × public-only-needs-1-auth_check = 1 call total. + assert mocked.call_count == 1 + + +async def test_file_size_cached_across_calls(fresh_auth_store): + """Once a successful HEAD lands, subsequent calls don't re-HEAD.""" + from app.model_downloader.gated_detection import probe_url + + url = "https://huggingface.co/public/repo/resolve/main/x.safetensors" + with patch( + "app.model_downloader.gated_detection._auth_check_sync" + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=2048), + ) as size_probe: + r1 = await probe_url(url) + r2 = await probe_url(url) + assert r1.file_size == 2048 + assert r2.file_size == 2048 + assert size_probe.call_count == 1 + + +async def test_file_size_not_probed_for_gated_no_access(fresh_auth_store): + """When ``is_hf_downloadable`` is False we must NOT HEAD the URL — + otherwise a 401-due-to-gating would land as a cached ``None`` that + survives a later successful login.""" + from app.model_downloader.gated_detection import probe_url + from huggingface_hub.errors import GatedRepoError + from unittest.mock import MagicMock + + url = "https://huggingface.co/gated/repo/resolve/main/x.safetensors" + fake_resp = MagicMock(status_code=403) + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=GatedRepoError("gated", response=fake_resp), + ), patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ) as size_probe: + result = await probe_url(url) + assert result.is_hf_downloadable is False + assert result.file_size is None + assert size_probe.call_count == 0 + + +async def test_probe_url_passes_token_when_available(fresh_auth_store, patched_user_dir): + """For a gated URL, auth_check runs twice: once with token=None to + determine the intrinsic ``is_gated`` flag (cached forever), and once + with the stored access_token to determine ``is_hf_downloadable`` for + the current user.""" + from app.model_downloader import gated_detection + from app.model_downloader.gated_detection import probe_url + from huggingface_hub.errors import GatedRepoError + from unittest.mock import MagicMock + + gated_detection.clear_caches_for_tests() + fresh_auth_store.set_token(Token( + access_token="hf_test_token", + refresh_token=None, + expires_at=9999999999.0, + )) + url = "https://huggingface.co/private/repo/resolve/main/x.safetensors" + + fake_resp = MagicMock(status_code=403) + + def fake_auth_check(repo_id, token): + # Null-token call → repo is gated. Subsequent call with the real + # token succeeds (user has access). + if token is None: + raise GatedRepoError("gated", response=fake_resp) + + with patch( + "app.model_downloader.gated_detection._auth_check_sync", + side_effect=fake_auth_check, + ) as mocked, patch( + "app.model_downloader.gated_detection._probe_size_once", + new=AsyncMock(return_value=None), + ): + result = await probe_url(url) + + # is_hf_downloadable should be True (token-authed call succeeded). + assert result.is_hf_downloadable is True + # Two calls: (repo_id, None) then (repo_id, ). + assert mocked.call_count == 2 + assert mocked.call_args_list[0].args == ("private/repo", None) + assert mocked.call_args_list[1].args == ("private/repo", "hf_test_token") + + +# --------------------------------------------------------------------------- # +# OAuth primitives +# --------------------------------------------------------------------------- # + + +def test_make_pkce_returns_distinct_high_entropy_values(): + verifier1, challenge1, state1 = oauth._make_pkce() + verifier2, challenge2, state2 = oauth._make_pkce() + assert verifier1 != verifier2 + assert challenge1 != challenge2 + assert state1 != state2 + # Verifier should be at least 43 chars per PKCE spec. + assert len(verifier1) >= 43 + + +def test_build_authorize_url_includes_pkce_and_state(): + url = oauth._build_authorize_url("challenge123", "state456") + assert url.startswith(oauth.AUTHORIZE_URL) + assert "client_id=" + oauth.HF_CLIENT_ID in url + assert "code_challenge=challenge123" in url + assert "code_challenge_method=S256" in url + assert "state=state456" in url + assert "response_type=code" in url + + +# --------------------------------------------------------------------------- # +# Routes +# --------------------------------------------------------------------------- # + + +async def test_hf_auth_token_status_empty(aiohttp_client, app): + """No token set → token_available=false, username=null.""" + client = await aiohttp_client(app) + resp = await client.get("/api/hf-auth-token-status") + assert resp.status == 200 + data = await resp.json() + assert data == {"token_available": False, "username": None} + + +async def test_hf_auth_token_status_with_token( + aiohttp_client, app, fresh_auth_store, patched_user_dir +): + """Token present, whoami works → username is returned.""" + fresh_auth_store.set_token(Token( + access_token="x", refresh_token=None, expires_at=9999999999.0, + )) + with patch( + "app.model_downloader.api.routes._whoami_username", + return_value="alice", + ): + client = await aiohttp_client(app) + resp = await client.get("/api/hf-auth-token-status") + assert resp.status == 200 + assert (await resp.json()) == {"token_available": True, "username": "alice"} + + +async def test_hf_auth_login_start_403_when_ineligible(aiohttp_client, app, monkeypatch): + """Not loopback / multi-user → 403.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: False, + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 403 + assert (await resp.json())["error"]["code"] == "HF_AUTH_NOT_ELIGIBLE" + + +async def test_hf_auth_login_start_returns_authorize_url(aiohttp_client, app, monkeypatch): + """Eligible + first attempt → 200 with authorize_url.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(return_value="https://huggingface.co/oauth/authorize?fake=1"), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 200 + assert (await resp.json())["authorize_url"].startswith( + "https://huggingface.co/oauth/authorize" + ) + + +async def test_hf_auth_login_start_409_when_in_progress(aiohttp_client, app, monkeypatch): + """Lock already held → 409.""" + from app.model_downloader.hf_auth.oauth import OAuthInProgressError + + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + monkeypatch.setattr( + "app.model_downloader.api.routes.start_login_flow", + AsyncMock(side_effect=OAuthInProgressError()), + ) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-login-start") + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "HF_AUTH_IN_PROGRESS" + + +async def test_hf_auth_logout_clears_store( + aiohttp_client, app, fresh_auth_store, patched_user_dir +): + fresh_auth_store.set_token(Token( + access_token="x", refresh_token=None, expires_at=9999999999.0, + )) + client = await aiohttp_client(app) + resp = await client.post("/api/hf-auth-logout") + assert resp.status == 200 + assert (await resp.json()) == {"logged_out": True} + assert not fresh_auth_store.has_token() + + +async def test_availability_includes_hf_auth_snapshot(aiohttp_client, app, monkeypatch): + """The availability response embeds {token_available, eligible}.""" + monkeypatch.setattr( + "app.model_downloader.api.routes.is_hf_auth_eligible", + lambda: True, + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/models-availability-status", + json={"models": {}}, + ) + assert resp.status == 200 + data = await resp.json() + assert "hf_auth" in data + assert data["hf_auth"] == {"token_available": False, "eligible": True} diff --git a/tests-unit/app_test/model_downloader_test.py b/tests-unit/app_test/model_downloader_test.py new file mode 100644 index 000000000..581700793 --- /dev/null +++ b/tests-unit/app_test/model_downloader_test.py @@ -0,0 +1,504 @@ +"""Unit tests for the server-side model download subsystem. + +Covers the pieces that don't require talking to a real network: + + - path parsing & allowlist (pure functions) + - DownloadServer registry lifecycle (in-memory state) + - API routes via aiohttp_client + folder_paths/probe_url patches + +Streaming downloads themselves are exercised indirectly — the route-level +tests stub out the network probe so we can verify the gating logic in +``download_models`` without making real HTTP calls. +""" + +from __future__ import annotations + +import asyncio +import os +from unittest.mock import patch, AsyncMock + +import pytest +from aiohttp import web + +from app.model_downloader.allowlist import is_url_allowed +from app.model_downloader.api.routes import register_routes +from app.model_downloader.download_server import DownloadServer +from app.model_downloader.gated_detection import MetadataProbeResult +from app.model_downloader.paths import ( + InvalidModelId, + parse_model_id, + resolve_destination, + resolve_existing, +) + +# Global asyncio mark: the sync tests below trigger a cosmetic +# PytestWarning for each one because pytest-asyncio applies the mark +# indiscriminately. Other tests in this repo (see custom_node_manager_test.py) +# use the same pattern. The warnings are noise, not failures. +pytestmark = pytest.mark.asyncio + + +# --------------------------------------------------------------------------- # +# Fixtures +# --------------------------------------------------------------------------- # + + +@pytest.fixture +def model_root(tmp_path): + """A fake ``models/`` root with two registered folder types.""" + loras_dir = tmp_path / "loras" + checkpoints_dir = tmp_path / "checkpoints" + loras_dir.mkdir() + checkpoints_dir.mkdir() + return tmp_path, loras_dir, checkpoints_dir + + +@pytest.fixture +def patched_folder_paths(model_root): + """Point folder_paths at our fake roots for the duration of one test.""" + _root, loras_dir, checkpoints_dir = model_root + mapping = { + "loras": ([str(loras_dir)], {".safetensors"}), + "checkpoints": ([str(checkpoints_dir)], {".safetensors"}), + } + with patch( + "folder_paths.folder_names_and_paths", mapping + ), patch( + "folder_paths.get_folder_paths", + side_effect=lambda name: mapping.get(name, ([], set()))[0], + ): + yield mapping + + +@pytest.fixture +def fresh_download_server(): + """Reset the module-level singleton between tests so registry state + doesn't leak across tests sharing the singleton.""" + from app.model_downloader.download_server import DOWNLOAD_SERVER + + DOWNLOAD_SERVER.reset_for_tests() + yield DOWNLOAD_SERVER + DOWNLOAD_SERVER.reset_for_tests() + + +@pytest.fixture +def app(patched_folder_paths, fresh_download_server): + app = web.Application() + register_routes(app) + return app + + +# --------------------------------------------------------------------------- # +# Pure helpers: allowlist + path parsing +# --------------------------------------------------------------------------- # + + +def test_allowlist_accepts_hf_safetensors(): + assert is_url_allowed("https://huggingface.co/x/y/resolve/main/z.safetensors") + + +def test_allowlist_accepts_civitai_pth(): + assert is_url_allowed("https://civitai.com/api/download/models/123.pth") + + +def test_allowlist_rejects_unknown_host(): + assert not is_url_allowed("https://example.com/x.safetensors") + + +def test_allowlist_rejects_api_path_on_hf(): + # On an allowlisted host but not pointing at a model file. + assert not is_url_allowed("https://huggingface.co/api/models") + + +def test_allowlist_rejects_non_https_except_localhost(): + assert not is_url_allowed("http://huggingface.co/x/y.safetensors") + assert is_url_allowed("http://localhost:8000/x.safetensors") + + +def test_parse_model_id_valid(patched_folder_paths): + assert parse_model_id("loras/foo.safetensors") == ("loras", "foo.safetensors") + + +def test_parse_model_id_rejects_traversal(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("../etc/passwd") + + +def test_parse_model_id_rejects_unknown_folder(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("nope/x.safetensors") + + +def test_parse_model_id_rejects_double_slash(patched_folder_paths): + with pytest.raises(InvalidModelId): + parse_model_id("loras/sub/x.safetensors") + + +def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths): + _root, loras_dir, _ = model_root + target = loras_dir / "foo.safetensors" + target.write_bytes(b"x") + assert resolve_existing("loras/foo.safetensors") == str(target) + + +def test_resolve_existing_returns_none_when_absent(patched_folder_paths): + assert resolve_existing("loras/missing.safetensors") is None + + +def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths): + _root, loras_dir, _ = model_root + final, tmp = resolve_destination("loras/foo.safetensors") + assert final == str(loras_dir / "foo.safetensors") + assert tmp == final + ".tmp" + + +# --------------------------------------------------------------------------- # +# DownloadServer registry: lifecycle, races, cancellation epoch semantics +# --------------------------------------------------------------------------- # + + +def test_register_is_exclusive(): + server = DownloadServer() + s1 = server.try_register("loras/x.safetensors", "https://huggingface.co/a") + s2 = server.try_register("loras/x.safetensors", "https://huggingface.co/b") + assert s1 is not None + assert s2 is None + assert server.is_downloading("loras/x.safetensors") + + +def test_cancel_removes_session(): + server = DownloadServer() + server.try_register("loras/x.safetensors", "https://huggingface.co/a") + assert server.cancel("loras/x.safetensors") is True + assert not server.is_downloading("loras/x.safetensors") + + +def test_cancel_returns_false_when_absent(): + server = DownloadServer() + assert server.cancel("loras/never.safetensors") is False + + +def test_finish_only_clears_matching_epoch(): + """If a session is cancelled and a new one for the same id is + registered, ``finish`` from the original worker must not evict the + newer session.""" + server = DownloadServer() + s_old = server.try_register("loras/x.safetensors", "u1") + server.cancel("loras/x.safetensors") + s_new = server.try_register("loras/x.safetensors", "u2") + assert s_new is not None and s_new.epoch != s_old.epoch + # Old worker's late finish() is a no-op: + server.finish(s_old) + assert server.is_downloading("loras/x.safetensors") + server.finish(s_new) + assert not server.is_downloading("loras/x.safetensors") + + +def test_is_active_follows_cancellation(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + assert server.is_active(s) + server.cancel("loras/x.safetensors") + assert not server.is_active(s) + + +def test_update_progress_tracks_fraction(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + server.update_progress(s, 50, 100) + snap = server.snapshot()["loras/x.safetensors"] + assert snap.bytes_downloaded == 50 + assert snap.total_bytes == 100 + assert snap.progress == 0.5 + + +def test_update_progress_with_unknown_total_keeps_progress_none(): + server = DownloadServer() + s = server.try_register("loras/x.safetensors", "u") + server.update_progress(s, 50, None) + assert server.snapshot()["loras/x.safetensors"].progress is None + + +def test_cleanup_orphan_tmp_files(model_root): + """Orphan .tmp left by a crashed download must be swept on first use.""" + _root, loras_dir, _ = model_root + orphan = loras_dir / "stale.safetensors.tmp" + orphan.write_bytes(b"partial") + mapping = {"loras": ([str(loras_dir)], {".safetensors"})} + with patch("folder_paths.folder_names_and_paths", mapping), patch( + "folder_paths.get_folder_paths", + side_effect=lambda name: mapping.get(name, ([], set()))[0], + ): + server = DownloadServer() + assert orphan.exists(), "sweep must not run at construction time" + server.sweep_orphan_tmp_files() + assert not orphan.exists() + # Idempotent — a second call is a cheap no-op. + server.sweep_orphan_tmp_files() + + +# --------------------------------------------------------------------------- # +# Route: POST /api/models-availability-status +# --------------------------------------------------------------------------- # + + +async def test_availability_partitions_correctly( + aiohttp_client, app, model_root, fresh_download_server +): + _root, loras_dir, _ = model_root + (loras_dir / "present.safetensors").write_bytes(b"x") + fresh_download_server.try_register( + "loras/inflight.safetensors", "http://localhost:8000/x.safetensors" + ) + client = await aiohttp_client(app) + + # Stub probes — we're testing state assignment, not network calls. + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult( + file_size=None, is_hf_downloadable=None, + )), + ): + body = { + "models": { + "loras/present.safetensors": "http://localhost:8000/p.safetensors", + "loras/missing.safetensors": "http://localhost:8000/m.safetensors", + "loras/inflight.safetensors": "http://localhost:8000/x.safetensors", + } + } + resp = await client.post("/api/models-availability-status", json=body) + assert resp.status == 200 + data = await resp.json() + models = data["models"] + assert models["loras/present.safetensors"]["state"] == "available" + assert models["loras/missing.safetensors"]["state"] == "missing" + assert models["loras/inflight.safetensors"]["state"] == "downloading" + assert "hf_auth" in data + + +async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app): + client = await aiohttp_client(app) + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult( + file_size=None, is_hf_downloadable=None, + )), + ): + resp = await client.post( + "/api/models-availability-status", + json={"models": {"../etc/passwd": "http://localhost:8000/x.safetensors"}}, + ) + assert resp.status == 200 + data = await resp.json() + assert data["models"]["../etc/passwd"]["state"] == "missing" + + +# --------------------------------------------------------------------------- # +# Route: POST /api/download-models — precondition gating +# --------------------------------------------------------------------------- # + + +async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}}, + ) + assert resp.status == 400 + err = (await resp.json())["error"] + assert err["code"] == "URL_NOT_ALLOWED" + + +async def test_download_rejects_already_available( + aiohttp_client, app, model_root +): + _root, loras_dir, _ = model_root + (loras_dir / "x.safetensors").write_bytes(b"x") + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE" + + +async def test_download_rejects_already_downloading( + aiohttp_client, app, fresh_download_server +): + fresh_download_server.try_register( + "loras/x.safetensors", "https://huggingface.co/u.safetensors" + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/a/b/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 409 + assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING" + + +async def test_download_rejects_gated_model(aiohttp_client, app): + client = await aiohttp_client(app) + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=False)), + ): + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/x.safetensors": "https://huggingface.co/g/r/resolve/main/x.safetensors" + }}, + ) + assert resp.status == 400 + assert (await resp.json())["error"]["code"] == "MODEL_NOT_DOWNLOADABLE" + + +async def test_download_rejects_invalid_model_id(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}}, + ) + assert resp.status == 400 + assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID" + + +async def test_download_atomic_failure_does_not_register_partial( + aiohttp_client, app, model_root, fresh_download_server +): + """If one model in a batch fails, none get registered.""" + _root, loras_dir, _ = model_root + (loras_dir / "already.safetensors").write_bytes(b"x") + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={ + "models": { + "loras/already.safetensors": + "https://huggingface.co/a/b/resolve/main/already.safetensors", + "loras/new.safetensors": + "https://huggingface.co/a/b/resolve/main/new.safetensors", + } + }, + ) + assert resp.status == 409 + # The "new" model should not have been registered as part of the + # failed batch. + assert not fresh_download_server.is_downloading("loras/new.safetensors") + + +async def test_download_schedules_when_all_preconditions_pass( + aiohttp_client, app, fresh_download_server +): + """Verify the precondition pass, registration pass, and async + scheduling all wire up correctly. We patch the streamer to avoid + real HTTP while still letting the route execute end-to-end.""" + started = asyncio.Event() + finish_signal = asyncio.Event() + + async def fake_stream(session): + started.set() + await finish_signal.wait() + from app.model_downloader.download_server import DOWNLOAD_SERVER + DOWNLOAD_SERVER.finish(session) + return "/dev/null" + + with patch( + "app.model_downloader.api.routes.probe_url", + new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)), + ), patch( + "app.model_downloader.downloader.stream_to_disk", new=fake_stream + ): + client = await aiohttp_client(app) + resp = await client.post( + "/api/download-models", + json={"models": { + "loras/new.safetensors": + "https://huggingface.co/a/b/resolve/main/new.safetensors" + }}, + ) + assert resp.status == 202 + body = await resp.json() + assert body["accepted"] is True + assert body["scheduled"] == ["loras/new.safetensors"] + # Wait for the worker to actually start. + await asyncio.wait_for(started.wait(), timeout=2.0) + assert fresh_download_server.is_downloading("loras/new.safetensors") + finish_signal.set() + + +# --------------------------------------------------------------------------- # +# Route: POST /api/cancel-model-download-session +# --------------------------------------------------------------------------- # + + +async def test_cancel_removes_active_session( + aiohttp_client, app, fresh_download_server +): + fresh_download_server.try_register( + "loras/x.safetensors", "https://huggingface.co/u.safetensors" + ) + client = await aiohttp_client(app) + resp = await client.post( + "/api/cancel-model-download-session", + json={"model_id": "loras/x.safetensors"}, + ) + assert resp.status == 200 + assert (await resp.json())["cancelled"] is True + assert not fresh_download_server.is_downloading("loras/x.safetensors") + + +async def test_cancel_returns_404_when_none(aiohttp_client, app): + client = await aiohttp_client(app) + resp = await client.post( + "/api/cancel-model-download-session", + json={"model_id": "loras/nothing.safetensors"}, + ) + assert resp.status == 404 + assert (await resp.json())["error"]["code"] == "NOT_DOWNLOADING" + + +# --------------------------------------------------------------------------- # +# Unified availability response embeds metadata per id +# --------------------------------------------------------------------------- # + + +async def test_availability_embeds_metadata(aiohttp_client, app): + """``file_size`` + ``is_hf_downloadable`` come back on the same + request as the state — no separate metadata endpoint.""" + results = { + "https://huggingface.co/a/b/resolve/main/free.safetensors": + MetadataProbeResult(file_size=1024, is_hf_downloadable=True), + "https://huggingface.co/g/r/resolve/main/gated.safetensors": + MetadataProbeResult(file_size=None, is_hf_downloadable=False), + } + + async def fake_probe(url): + return results[url] + + with patch( + "app.model_downloader.api.routes.probe_url", new=fake_probe + ): + client = await aiohttp_client(app) + resp = await client.post( + "/api/models-availability-status", + json={ + "models": { + "loras/free.safetensors": + "https://huggingface.co/a/b/resolve/main/free.safetensors", + "loras/gated.safetensors": + "https://huggingface.co/g/r/resolve/main/gated.safetensors", + } + }, + ) + assert resp.status == 200 + models = (await resp.json())["models"] + assert models["loras/free.safetensors"]["file_size"] == 1024 + assert models["loras/free.safetensors"]["is_hf_downloadable"] is True + assert models["loras/gated.safetensors"]["file_size"] is None + assert models["loras/gated.safetensors"]["is_hf_downloadable"] is False