"""Integration tests for the download engine against a local aiohttp server. Covers single-stream and segmented transfers, deterministic resume from a partial file, and cancel rollback. Async tests are driven via ``asyncio.run`` so no pytest-asyncio plugin is required. """ from __future__ import annotations import asyncio import os import uuid import pytest from aiohttp import web from comfy.cli_args import args from app.model_downloader.constants import DownloadStatus from app.model_downloader.database import queries from app.model_downloader.engine.job import DownloadJob, JobSpec from app.model_downloader.net.session import close_session from app.model_downloader.security import paths PAYLOAD_ETAG = '"v1"' def _payload(n: int) -> bytes: return bytes((i * 37 + 11) % 256 for i in range(n)) def _range_handler(payload: bytes): async def handler(request: web.Request) -> web.Response: rng = request.headers.get("Range") if rng: spec = rng.split("=", 1)[1] s, _, e = spec.partition("-") start = int(s) end = int(e) if e else len(payload) - 1 chunk = payload[start : end + 1] return web.Response( status=206, body=chunk, headers={ "Content-Range": f"bytes {start}-{end}/{len(payload)}", "Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG, }, ) return web.Response( status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} ) return handler def _noranges_handler(payload: bytes): async def handler(request: web.Request) -> web.Response: # Always full body, never advertises Accept-Ranges -> single-stream. return web.Response(status=200, body=payload) return handler def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01): async def handler(request: web.Request) -> web.StreamResponse: resp = web.StreamResponse( status=200, headers={"Content-Length": str(len(payload))} ) await resp.prepare(request) for i in range(0, len(payload), chunk): await resp.write(payload[i : i + chunk]) await asyncio.sleep(delay) await resp.write_eof() return resp return handler async def _serve(handler): app = web.Application() app.router.add_route("*", "/{name:.*}", handler) runner = web.AppRunner(app) await runner.setup() site = web.TCPSite(runner, "127.0.0.1", 0) await site.start() port = site._server.sockets[0].getsockname()[1] return runner, port def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]: final_path, temp_path = paths.resolve_destination(model_id) download_id = str(uuid.uuid4()) queries.insert_download( { "id": download_id, "url": url, "model_id": model_id, "dest_path": final_path, "temp_path": temp_path, "status": status, } ) return download_id, final_path, temp_path # ----- single-stream ----- def test_single_stream_download(model_root): payload = _payload(300_000) async def _run(): await close_session() runner, port = await _serve(_noranges_handler(payload)) try: url = f"http://127.0.0.1:{port}/model.safetensors" did, final_path, _temp = _insert("loras/single.safetensors", url) job = DownloadJob(JobSpec( download_id=did, url=url, model_id="loras/single.safetensors", dest_path=final_path, temp_path=_temp, )) status = await job.run() assert status == DownloadStatus.COMPLETED, queries.get_download(did).error assert os.path.exists(final_path) assert open(final_path, "rb").read() == payload finally: await runner.cleanup() await close_session() asyncio.run(_run()) # ----- segmented ----- def test_segmented_download(model_root): payload = _payload(4 * 1024 * 1024) # 4 MiB -> multiple segments async def _run(): await close_session() runner, port = await _serve(_range_handler(payload)) try: url = f"http://127.0.0.1:{port}/model.safetensors" did, final_path, temp = _insert("loras/seg.safetensors", url) job = DownloadJob(JobSpec( download_id=did, url=url, model_id="loras/seg.safetensors", dest_path=final_path, temp_path=temp, )) status = await job.run() assert status == DownloadStatus.COMPLETED, queries.get_download(did).error assert open(final_path, "rb").read() == payload # More than one segment row was planned. assert len(queries.list_segments(did)) > 1 finally: await runner.cleanup() await close_session() asyncio.run(_run()) # ----- deterministic resume from a partial file ----- def test_resume_from_partial(model_root): payload = _payload(512 * 1024) # < 1 MiB -> single segment, but ranges work async def _run(): await close_session() runner, port = await _serve(_range_handler(payload)) try: url = f"http://127.0.0.1:{port}/model.safetensors" did, final_path, temp = _insert("loras/resume.safetensors", url) # Simulate a prior partial: first 200 KiB already written, offset persisted. prefix = 200 * 1024 os.makedirs(os.path.dirname(temp), exist_ok=True) with open(temp, "wb") as f: f.write(payload[:prefix]) queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG) job = DownloadJob(JobSpec( download_id=did, url=url, model_id="loras/resume.safetensors", dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG, )) status = await job.run() assert status == DownloadStatus.COMPLETED, queries.get_download(did).error assert open(final_path, "rb").read() == payload finally: await runner.cleanup() await close_session() asyncio.run(_run()) # ----- cancel rollback ----- def test_cancel_rollback(model_root, monkeypatch): monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False) payload = _payload(1024 * 1024) async def _run(): await close_session() runner, port = await _serve(_slow_handler(payload)) try: url = f"http://127.0.0.1:{port}/model.safetensors" did, final_path, temp = _insert("loras/cancel.safetensors", url) job = DownloadJob(JobSpec( download_id=did, url=url, model_id="loras/cancel.safetensors", dest_path=final_path, temp_path=temp, )) task = asyncio.ensure_future(job.run()) # Wait until some bytes have been written, then cancel. for _ in range(200): await asyncio.sleep(0.01) if job.state.bytes_done > 0: break job.request_cancel() status = await task assert status == DownloadStatus.CANCELLED assert not os.path.exists(temp) assert not os.path.exists(final_path) finally: await runner.cleanup() await close_session() asyncio.run(_run()) # ----- manager + scheduler end-to-end ----- def test_manager_enqueue_to_completion(model_root): payload = _payload(2 * 1024 * 1024) async def _run(): await close_session() from app.model_downloader.manager import DOWNLOAD_MANAGER runner, port = await _serve(_range_handler(payload)) try: url = f"http://127.0.0.1:{port}/model.safetensors" did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors") # Wait for completion. final_path, _ = paths.resolve_destination("loras/e2e.safetensors") for _ in range(500): await asyncio.sleep(0.02) row = queries.get_download(did) if row.status in DownloadStatus.TERMINAL: break row = queries.get_download(did) assert row.status == DownloadStatus.COMPLETED, row.error assert open(final_path, "rb").read() == payload finally: await runner.cleanup() await close_session() asyncio.run(_run()) def test_manager_rejects_disallowed_url(model_root): async def _run(): from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError with pytest.raises(DownloadError) as ei: await DOWNLOAD_MANAGER.enqueue( "https://evil.example.com/x.safetensors", "loras/bad.safetensors" ) assert ei.value.code == "URL_NOT_ALLOWED" asyncio.run(_run())