ComfyUI/tests-unit/app_test/model_downloader_test.py
DoronGenzelHass dcd61f1132 test(model_downloader): unit tests for downloads, allowlist, paths, and HF auth
Co-Authored-By: Claude Opus 4.8 (1M context) <noreply@anthropic.com>
2026-06-22 15:22:11 +03:00

505 lines
18 KiB
Python

"""Unit tests for the server-side model download subsystem.
Covers the pieces that don't require talking to a real network:
- path parsing & allowlist (pure functions)
- DownloadServer registry lifecycle (in-memory state)
- API routes via aiohttp_client + folder_paths/probe_url patches
Streaming downloads themselves are exercised indirectly — the route-level
tests stub out the network probe so we can verify the gating logic in
``download_models`` without making real HTTP calls.
"""
from __future__ import annotations
import asyncio
import os
from unittest.mock import patch, AsyncMock
import pytest
from aiohttp import web
from app.model_downloader.allowlist import is_url_allowed
from app.model_downloader.api.routes import register_routes
from app.model_downloader.download_server import DownloadServer
from app.model_downloader.gated_detection import MetadataProbeResult
from app.model_downloader.paths import (
InvalidModelId,
parse_model_id,
resolve_destination,
resolve_existing,
)
# Module-wide marker: every test in this file is tagged ``asyncio``. The
# synchronous tests get the tag too (``pytestmark`` does not discriminate),
# which makes pytest-asyncio emit one cosmetic PytestWarning per sync test.
# custom_node_manager_test.py uses the same pattern; the warnings are noise
# and do not fail the run.
pytestmark = [pytest.mark.asyncio]
# --------------------------------------------------------------------------- #
# Fixtures
# --------------------------------------------------------------------------- #
@pytest.fixture
def model_root(tmp_path):
    """Create a throwaway ``models/`` tree with two folder types.

    Returns a ``(root, loras_dir, checkpoints_dir)`` tuple, where both
    sub-directories already exist under ``root`` (pytest's ``tmp_path``).
    """
    subdirs = [tmp_path / folder for folder in ("loras", "checkpoints")]
    for subdir in subdirs:
        subdir.mkdir()
    loras_dir, checkpoints_dir = subdirs
    return tmp_path, loras_dir, checkpoints_dir
@pytest.fixture
def patched_folder_paths(model_root):
    """Redirect ``folder_paths`` to the temporary model tree for one test.

    Both ``folder_paths.folder_names_and_paths`` and
    ``folder_paths.get_folder_paths`` are patched so only the ``loras`` and
    ``checkpoints`` folder types exist, each pointing at its temp directory
    and accepting ``.safetensors``. Unknown folder types resolve to ``[]``.
    Yields the installed mapping.
    """
    _, loras_dir, checkpoints_dir = model_root
    # Each entry gets its own extension set (evaluated per iteration).
    fake_mapping = {
        folder_type: ([str(folder_dir)], {".safetensors"})
        for folder_type, folder_dir in (
            ("loras", loras_dir),
            ("checkpoints", checkpoints_dir),
        )
    }

    def fake_get_folder_paths(name):
        paths, _extensions = fake_mapping.get(name, ([], set()))
        return paths

    with patch("folder_paths.folder_names_and_paths", fake_mapping):
        with patch(
            "folder_paths.get_folder_paths", side_effect=fake_get_folder_paths
        ):
            yield fake_mapping
@pytest.fixture
def fresh_download_server():
    """Yield the module-level ``DOWNLOAD_SERVER`` singleton with a clean registry.

    The registry is reset both before and after the test, so session state
    cannot leak between tests that share the singleton.
    """
    from app.model_downloader.download_server import DOWNLOAD_SERVER as server

    server.reset_for_tests()
    yield server
    server.reset_for_tests()
@pytest.fixture
def app(patched_folder_paths, fresh_download_server):
    """Build an aiohttp application with the model-downloader routes mounted.

    ``patched_folder_paths`` and ``fresh_download_server`` are requested only
    for their side effects: a patched folder layout and a clean registry.
    """
    application = web.Application()
    register_routes(application)
    return application
# --------------------------------------------------------------------------- #
# Pure helpers: allowlist + path parsing
# --------------------------------------------------------------------------- #
def test_allowlist_accepts_hf_safetensors():
    """A Hugging Face ``resolve`` URL for a .safetensors file is allowed."""
    hf_url = "https://huggingface.co/x/y/resolve/main/z.safetensors"
    assert is_url_allowed(hf_url)


def test_allowlist_accepts_civitai_pth():
    """A Civitai download URL ending in .pth is allowed."""
    civitai_url = "https://civitai.com/api/download/models/123.pth"
    assert is_url_allowed(civitai_url)


def test_allowlist_rejects_unknown_host():
    """A host outside the allowlist is refused even for a model extension."""
    foreign_url = "https://example.com/x.safetensors"
    assert not is_url_allowed(foreign_url)


def test_allowlist_rejects_api_path_on_hf():
    """An allowlisted host is still refused when the path is not a model file."""
    api_url = "https://huggingface.co/api/models"
    assert not is_url_allowed(api_url)


def test_allowlist_rejects_non_https_except_localhost():
    """Plain http is refused for remote hosts but permitted for localhost."""
    remote_http = "http://huggingface.co/x/y.safetensors"
    local_http = "http://localhost:8000/x.safetensors"
    assert not is_url_allowed(remote_http)
    assert is_url_allowed(local_http)
def test_parse_model_id_valid(patched_folder_paths):
    """A ``<folder>/<file>`` id splits into its folder type and file name."""
    expected = ("loras", "foo.safetensors")
    assert parse_model_id("loras/foo.safetensors") == expected


def test_parse_model_id_rejects_traversal(patched_folder_paths):
    """An id starting with a ``..`` component is rejected."""
    traversal_id = "../etc/passwd"
    with pytest.raises(InvalidModelId):
        parse_model_id(traversal_id)


def test_parse_model_id_rejects_unknown_folder(patched_folder_paths):
    """An id whose folder type is not registered in folder_paths is rejected."""
    unknown_folder_id = "nope/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(unknown_folder_id)
def test_parse_model_id_rejects_double_slash(patched_folder_paths):
    """Ids containing more than one ``/`` (a nested sub-path) are rejected.

    Despite the historical name, this does not exercise a literal ``//``:
    it checks that a ``<folder>/<subdir>/<file>`` id is refused, i.e. ids
    appear to be limited to the flat ``<folder>/<file>`` form. The name is
    kept so existing ``-k`` selections keep matching.
    """
    nested_id = "loras/sub/x.safetensors"
    with pytest.raises(InvalidModelId):
        parse_model_id(nested_id)
def test_resolve_existing_returns_path_when_present(model_root, patched_folder_paths):
    """An id whose file exists on disk resolves to that file's path string."""
    loras_dir = model_root[1]
    present_file = loras_dir / "foo.safetensors"
    present_file.write_bytes(b"x")
    assert resolve_existing("loras/foo.safetensors") == str(present_file)


def test_resolve_existing_returns_none_when_absent(patched_folder_paths):
    """An id with no file on disk resolves to ``None``."""
    assert resolve_existing("loras/missing.safetensors") is None


def test_resolve_destination_returns_tmp_pair(model_root, patched_folder_paths):
    """``resolve_destination`` returns the final path and its ``.tmp`` sibling."""
    loras_dir = model_root[1]
    final_path, staging_path = resolve_destination("loras/foo.safetensors")
    assert final_path == str(loras_dir / "foo.safetensors")
    assert staging_path == final_path + ".tmp"
# --------------------------------------------------------------------------- #
# DownloadServer registry: lifecycle, races, cancellation epoch semantics
# --------------------------------------------------------------------------- #
def test_register_is_exclusive():
    """Only the first ``try_register`` for a given id succeeds."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    first = server.try_register(model_id, "https://huggingface.co/a")
    second = server.try_register(model_id, "https://huggingface.co/b")
    assert first is not None
    assert second is None
    assert server.is_downloading(model_id)


def test_cancel_removes_session():
    """Cancelling a registered id returns True and clears it from the registry."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    server.try_register(model_id, "https://huggingface.co/a")
    assert server.cancel(model_id) is True
    assert not server.is_downloading(model_id)


def test_cancel_returns_false_when_absent():
    """Cancelling an id that was never registered returns False."""
    assert DownloadServer().cancel("loras/never.safetensors") is False
def test_finish_only_clears_matching_epoch():
    """A stale ``finish`` must not evict a newer session for the same id.

    Sequence: register -> cancel -> re-register. The original worker's late
    ``finish(s_old)`` must be a no-op because the registry now holds the
    newer session (different epoch); only ``finish(s_new)`` clears the id.
    """
    server = DownloadServer()
    model_id = "loras/x.safetensors"

    s_old = server.try_register(model_id, "u1")
    assert s_old is not None
    assert server.cancel(model_id) is True

    s_new = server.try_register(model_id, "u2")
    # Separate asserts so a failure pinpoints the broken precondition
    # instead of surfacing as an AttributeError on a ``None`` session.
    assert s_new is not None
    assert s_new.epoch != s_old.epoch

    # Old worker's late finish() is a no-op:
    server.finish(s_old)
    assert server.is_downloading(model_id)

    server.finish(s_new)
    assert not server.is_downloading(model_id)
def test_is_active_follows_cancellation():
    """A session stops being active once its id is cancelled."""
    server = DownloadServer()
    model_id = "loras/x.safetensors"
    session = server.try_register(model_id, "u")
    assert server.is_active(session)
    server.cancel(model_id)
    assert not server.is_active(session)


def test_update_progress_tracks_fraction():
    """Byte counters and the derived fraction are exposed via ``snapshot``."""
    server = DownloadServer()
    session = server.try_register("loras/x.safetensors", "u")
    server.update_progress(session, 50, 100)
    entry = server.snapshot()["loras/x.safetensors"]
    assert (entry.bytes_downloaded, entry.total_bytes) == (50, 100)
    assert entry.progress == 0.5


def test_update_progress_with_unknown_total_keeps_progress_none():
    """Without a known total size, ``progress`` stays ``None``."""
    server = DownloadServer()
    session = server.try_register("loras/x.safetensors", "u")
    server.update_progress(session, 50, None)
    entry = server.snapshot()["loras/x.safetensors"]
    assert entry.progress is None
def test_cleanup_orphan_tmp_files(model_root, patched_folder_paths):
    """Orphaned ``.tmp`` files from a crashed download are removed by an explicit sweep.

    Constructing a ``DownloadServer`` must not touch the filesystem; the
    sweep only happens when ``sweep_orphan_tmp_files`` is called. It must
    remove ``.tmp`` leftovers only, leave completed model files alone, and
    be safe to call again on an already-clean folder.
    """
    _root, loras_dir, _ = model_root
    orphan = loras_dir / "stale.safetensors.tmp"
    orphan.write_bytes(b"partial")
    # A finished model next to the orphan guards against an over-eager sweep.
    keeper = loras_dir / "keep.safetensors"
    keeper.write_bytes(b"complete")

    server = DownloadServer()
    assert orphan.exists(), "sweep must not run at construction time"

    server.sweep_orphan_tmp_files()
    assert not orphan.exists()
    assert keeper.exists(), "sweep must only delete .tmp files"

    # Idempotent: a second call must not raise or delete anything else.
    server.sweep_orphan_tmp_files()
    assert not orphan.exists()
    assert keeper.exists()
# --------------------------------------------------------------------------- #
# Route: POST /api/models-availability-status
# --------------------------------------------------------------------------- #
async def test_availability_partitions_correctly(
    aiohttp_client, app, model_root, fresh_download_server
):
    """Each requested id is classified as available, missing or downloading."""
    loras_dir = model_root[1]
    (loras_dir / "present.safetensors").write_bytes(b"x")
    fresh_download_server.try_register(
        "loras/inflight.safetensors", "http://localhost:8000/x.safetensors"
    )
    client = await aiohttp_client(app)

    request_body = {
        "models": {
            "loras/present.safetensors": "http://localhost:8000/p.safetensors",
            "loras/missing.safetensors": "http://localhost:8000/m.safetensors",
            "loras/inflight.safetensors": "http://localhost:8000/x.safetensors",
        }
    }
    expected_states = {
        "loras/present.safetensors": "available",
        "loras/missing.safetensors": "missing",
        "loras/inflight.safetensors": "downloading",
    }
    # The probe is stubbed: this test covers state assignment, not network I/O.
    probe_stub = AsyncMock(
        return_value=MetadataProbeResult(file_size=None, is_hf_downloadable=None)
    )
    with patch("app.model_downloader.api.routes.probe_url", new=probe_stub):
        resp = await client.post("/api/models-availability-status", json=request_body)
        assert resp.status == 200
        payload = await resp.json()

    models = payload["models"]
    for model_id, state in expected_states.items():
        assert models[model_id]["state"] == state
    assert "hf_auth" in payload
async def test_availability_invalid_id_classified_as_missing(aiohttp_client, app):
    """A malformed (traversal) id is reported as ``missing`` rather than erroring."""
    bad_id = "../etc/passwd"
    client = await aiohttp_client(app)
    no_metadata = MetadataProbeResult(file_size=None, is_hf_downloadable=None)
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=no_metadata),
    ):
        resp = await client.post(
            "/api/models-availability-status",
            json={"models": {bad_id: "http://localhost:8000/x.safetensors"}},
        )
        assert resp.status == 200
        payload = await resp.json()
    assert payload["models"][bad_id]["state"] == "missing"
# --------------------------------------------------------------------------- #
# Route: POST /api/download-models — precondition gating
# --------------------------------------------------------------------------- #
async def test_download_rejects_url_not_in_allowlist(aiohttp_client, app):
    """A download URL outside the allowlist is refused with 400 URL_NOT_ALLOWED."""
    client = await aiohttp_client(app)
    request_body = {
        "models": {"loras/x.safetensors": "https://evil.com/x.safetensors"}
    }
    resp = await client.post("/api/download-models", json=request_body)
    assert resp.status == 400
    payload = await resp.json()
    assert payload["error"]["code"] == "URL_NOT_ALLOWED"
async def test_download_rejects_already_available(
    aiohttp_client, app, model_root
):
    """Requesting a model whose file already exists fails with 409 ALREADY_AVAILABLE.

    ``probe_url`` is stubbed as downloadable so the test stays hermetic
    whatever order the route runs its precondition checks in. Without the
    stub, a route that probes before checking availability would make a
    real request to huggingface.co.
    """
    _root, loras_dir, _ = model_root
    (loras_dir / "x.safetensors").write_bytes(b"x")
    client = await aiohttp_client(app)
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(
            file_size=42, is_hf_downloadable=True,
        )),
    ):
        resp = await client.post(
            "/api/download-models",
            json={"models": {
                "loras/x.safetensors":
                    "https://huggingface.co/a/b/resolve/main/x.safetensors"
            }},
        )
        assert resp.status == 409
        assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE"
async def test_download_rejects_already_downloading(
    aiohttp_client, app, fresh_download_server
):
    """A model with an in-flight session is refused with 409 ALREADY_DOWNLOADING.

    ``probe_url`` is stubbed as downloadable so the test never reaches the
    network, whatever order the route runs its precondition checks in.
    """
    fresh_download_server.try_register(
        "loras/x.safetensors", "https://huggingface.co/u.safetensors"
    )
    client = await aiohttp_client(app)
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(
            file_size=42, is_hf_downloadable=True,
        )),
    ):
        resp = await client.post(
            "/api/download-models",
            json={"models": {
                "loras/x.safetensors":
                    "https://huggingface.co/a/b/resolve/main/x.safetensors"
            }},
        )
        assert resp.status == 409
        assert (await resp.json())["error"]["code"] == "ALREADY_DOWNLOADING"
async def test_download_rejects_gated_model(aiohttp_client, app):
    """A URL the probe reports as not downloadable yields 400 MODEL_NOT_DOWNLOADABLE."""
    client = await aiohttp_client(app)
    gated_result = MetadataProbeResult(file_size=None, is_hf_downloadable=False)
    gated_url = "https://huggingface.co/g/r/resolve/main/x.safetensors"
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=gated_result),
    ):
        resp = await client.post(
            "/api/download-models",
            json={"models": {"loras/x.safetensors": gated_url}},
        )
        assert resp.status == 400
        payload = await resp.json()
    assert payload["error"]["code"] == "MODEL_NOT_DOWNLOADABLE"
async def test_download_rejects_invalid_model_id(aiohttp_client, app):
    """A traversal-style model id is refused with 400 INVALID_MODEL_ID.

    ``probe_url`` is stubbed as downloadable so the test cannot reach the
    network even if the route probes the URL before validating the id.
    """
    client = await aiohttp_client(app)
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(
            file_size=42, is_hf_downloadable=True,
        )),
    ):
        resp = await client.post(
            "/api/download-models",
            json={"models": {"../etc/passwd": "https://huggingface.co/x.safetensors"}},
        )
        assert resp.status == 400
        assert (await resp.json())["error"]["code"] == "INVALID_MODEL_ID"
async def test_download_atomic_failure_does_not_register_partial(
    aiohttp_client, app, model_root, fresh_download_server
):
    """If any model in a batch fails its preconditions, none are registered.

    The valid model is listed *before* the failing one. A route that
    registered entries while still validating (instead of validate-all,
    then register-all) would register ``new`` before reaching ``already``,
    and this test would catch it. With the failing entry first, a fail-fast
    loop would never reach ``new`` and the test would pass vacuously.
    ``probe_url`` is stubbed as downloadable so ``new`` passes every check
    without network access.
    """
    _root, loras_dir, _ = model_root
    (loras_dir / "already.safetensors").write_bytes(b"x")
    client = await aiohttp_client(app)
    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(
            file_size=42, is_hf_downloadable=True,
        )),
    ):
        resp = await client.post(
            "/api/download-models",
            json={
                "models": {
                    # Valid entry first so a partial registration is observable.
                    "loras/new.safetensors":
                        "https://huggingface.co/a/b/resolve/main/new.safetensors",
                    "loras/already.safetensors":
                        "https://huggingface.co/a/b/resolve/main/already.safetensors",
                }
            },
        )
        assert resp.status == 409
        assert (await resp.json())["error"]["code"] == "ALREADY_AVAILABLE"
    # Neither model may have been registered by the failed batch.
    assert not fresh_download_server.is_downloading("loras/new.safetensors")
    assert not fresh_download_server.is_downloading("loras/already.safetensors")
async def test_download_schedules_when_all_preconditions_pass(
    aiohttp_client, app, fresh_download_server
):
    """Verify the precondition pass, registration pass and async scheduling.

    The streamer is replaced by a fake that blocks until released, so the
    route runs end-to-end without real HTTP. The fake is always released,
    even when an assertion fails, and the test waits for it to clear the
    registry entry. That way no worker task is left pending, and the patch
    is still active, when the test's event loop shuts down.
    """
    model_id = "loras/new.safetensors"
    started = asyncio.Event()
    finish_signal = asyncio.Event()

    async def fake_stream(session):
        # Signal that the worker is running, block until the test releases
        # it, then clear the session like a completed download would.
        started.set()
        await finish_signal.wait()
        fresh_download_server.finish(session)
        return "/dev/null"

    async def wait_until_idle():
        while fresh_download_server.is_downloading(model_id):
            await asyncio.sleep(0.01)

    with patch(
        "app.model_downloader.api.routes.probe_url",
        new=AsyncMock(return_value=MetadataProbeResult(file_size=42, is_hf_downloadable=True)),
    ), patch(
        "app.model_downloader.downloader.stream_to_disk", new=fake_stream
    ):
        client = await aiohttp_client(app)
        try:
            resp = await client.post(
                "/api/download-models",
                json={"models": {
                    model_id:
                        "https://huggingface.co/a/b/resolve/main/new.safetensors"
                }},
            )
            assert resp.status == 202
            body = await resp.json()
            assert body["accepted"] is True
            assert body["scheduled"] == [model_id]
            # Wait for the worker to actually start.
            await asyncio.wait_for(started.wait(), timeout=2.0)
            assert fresh_download_server.is_downloading(model_id)
        finally:
            # Release the worker even on failure so it never blocks forever.
            finish_signal.set()
        # Let the fake streamer finish while stream_to_disk is still patched.
        await asyncio.wait_for(wait_until_idle(), timeout=2.0)
# --------------------------------------------------------------------------- #
# Route: POST /api/cancel-model-download-session
# --------------------------------------------------------------------------- #
async def test_cancel_removes_active_session(
    aiohttp_client, app, fresh_download_server
):
    """Cancelling an in-flight id answers ``cancelled: true`` and unregisters it."""
    model_id = "loras/x.safetensors"
    fresh_download_server.try_register(
        model_id, "https://huggingface.co/u.safetensors"
    )
    client = await aiohttp_client(app)
    resp = await client.post(
        "/api/cancel-model-download-session", json={"model_id": model_id}
    )
    assert resp.status == 200
    payload = await resp.json()
    assert payload["cancelled"] is True
    assert not fresh_download_server.is_downloading(model_id)
async def test_cancel_returns_404_when_none(aiohttp_client, app):
    """Cancelling an id with no active session answers 404 NOT_DOWNLOADING."""
    client = await aiohttp_client(app)
    request_body = {"model_id": "loras/nothing.safetensors"}
    resp = await client.post("/api/cancel-model-download-session", json=request_body)
    assert resp.status == 404
    payload = await resp.json()
    assert payload["error"]["code"] == "NOT_DOWNLOADING"
# --------------------------------------------------------------------------- #
# Unified availability response embeds metadata per id
# --------------------------------------------------------------------------- #
async def test_availability_embeds_metadata(aiohttp_client, app):
    """Availability responses carry per-id ``file_size`` and ``is_hf_downloadable``.

    The metadata arrives in the same response as the state; there is no
    separate metadata endpoint.
    """
    free_url = "https://huggingface.co/a/b/resolve/main/free.safetensors"
    gated_url = "https://huggingface.co/g/r/resolve/main/gated.safetensors"
    probe_results = {
        free_url: MetadataProbeResult(file_size=1024, is_hf_downloadable=True),
        gated_url: MetadataProbeResult(file_size=None, is_hf_downloadable=False),
    }

    async def fake_probe(url):
        return probe_results[url]

    with patch("app.model_downloader.api.routes.probe_url", new=fake_probe):
        client = await aiohttp_client(app)
        resp = await client.post(
            "/api/models-availability-status",
            json={
                "models": {
                    "loras/free.safetensors": free_url,
                    "loras/gated.safetensors": gated_url,
                }
            },
        )
        assert resp.status == 200
        models = (await resp.json())["models"]

    free_entry = models["loras/free.safetensors"]
    gated_entry = models["loras/gated.safetensors"]
    assert free_entry["file_size"] == 1024
    assert free_entry["is_hf_downloadable"] is True
    assert gated_entry["file_size"] is None
    assert gated_entry["is_hf_downloadable"] is False