ComfyUI/tests-unit/model_downloader_test/test_engine_integration.py
2026-06-30 20:33:15 +02:00

271 lines
8.9 KiB
Python

"""Integration tests for the download engine against a local aiohttp server.
Covers single-stream and segmented transfers, deterministic resume from a
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
so no pytest-asyncio plugin is required.
"""
from __future__ import annotations
import asyncio
import os
import uuid
import pytest
from aiohttp import web
from comfy.cli_args import args
from app.model_downloader.constants import DownloadStatus
from app.model_downloader.database import queries
from app.model_downloader.engine.job import DownloadJob, JobSpec
from app.model_downloader.net.session import close_session
from app.model_downloader.security import paths
PAYLOAD_ETAG = '"v1"'
def _payload(n: int) -> bytes:
return bytes((i * 37 + 11) % 256 for i in range(n))
def _parse_range(header: str, size: int) -> tuple[int, int] | None:
    """Parse a single-range ``Range`` header against a body of ``size`` bytes.

    Accepts ``bytes=start-end``, ``bytes=start-`` and suffix ``bytes=-N``.

    Returns:
        Inclusive ``(start, end)`` offsets, with ``end`` clamped to
        ``size - 1``. Returns ``None`` in these cases:
        - the header is malformed;
        - it uses a unit other than ``bytes``;
        - it asks for several ranges;
        - the range cannot be satisfied.
    """
    unit, sep, spec = header.partition("=")
    if not sep or unit.strip().lower() != "bytes" or "," in spec:
        return None
    first, dash, last = spec.strip().partition("-")
    if not dash:
        return None
    try:
        if first:
            start = int(first)
            end = int(last) if last else size - 1
        else:
            # Suffix form ``bytes=-N``: the final N bytes of the body.
            suffix = int(last)
            if suffix <= 0:
                return None
            start = max(size - suffix, 0)
            end = size - 1
    except ValueError:
        return None
    if start < 0 or start >= size or end < start:
        return None
    return start, min(end, size - 1)


def _range_handler(payload: bytes):
    """Build a handler that serves ``payload`` with HTTP Range support.

    - No (or empty) ``Range`` header: the full body is returned with 200.
    - A parsable single range: 206 with the slice and a ``Content-Range``
      that always matches the bytes actually sent.
    - Anything unparsable or unsatisfiable: 416 with ``bytes */<size>``,
      instead of crashing the handler with a 500.

    The 200 and 206 responses advertise ``Accept-Ranges: bytes`` and carry
    ``PAYLOAD_ETAG``.
    """
    size = len(payload)

    async def handler(request: web.Request) -> web.Response:
        rng = request.headers.get("Range")
        if not rng:
            return web.Response(
                status=200,
                body=payload,
                headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG},
            )
        bounds = _parse_range(rng, size)
        if bounds is None:
            return web.Response(
                status=416, headers={"Content-Range": f"bytes */{size}"}
            )
        start, end = bounds
        return web.Response(
            status=206,
            body=payload[start : end + 1],
            headers={
                "Content-Range": f"bytes {start}-{end}/{size}",
                "Accept-Ranges": "bytes",
                "ETag": PAYLOAD_ETAG,
            },
        )

    return handler
def _noranges_handler(payload: bytes):
    """Build a handler that ignores ``Range`` and always returns the full body.

    It never sends an ``Accept-Ranges`` header. The single-stream test relies
    on this to keep the engine off the segmented path.
    """

    async def serve_whole_body(request: web.Request) -> web.Response:
        # Every request, ranged or not, gets the complete payload with 200.
        return web.Response(body=payload, status=200)

    return serve_whole_body
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
    """Build a handler that streams ``payload`` slowly, ``chunk`` bytes at a time.

    After each write it sleeps ``delay`` seconds. This keeps the transfer in
    flight long enough for the cancel test to interrupt it. Only
    ``Content-Length`` is advertised (no ``Accept-Ranges``).
    """

    async def trickle(request: web.Request) -> web.StreamResponse:
        stream = web.StreamResponse(
            status=200, headers={"Content-Length": str(len(payload))}
        )
        await stream.prepare(request)
        for offset in range(0, len(payload), chunk):
            piece = payload[offset : offset + chunk]
            await stream.write(piece)
            await asyncio.sleep(delay)
        await stream.write_eof()
        return stream

    return trickle
async def _serve(handler):
    """Start a local aiohttp server that routes every method and path to ``handler``.

    The server binds to an ephemeral port on 127.0.0.1.

    Returns:
        ``(runner, port)``. The caller owns ``runner`` and must
        ``await runner.cleanup()``.

    If binding or start-up fails, the runner is cleaned up before the error
    propagates, so no half-started server is leaked.
    """
    app = web.Application()
    app.router.add_route("*", "/{name:.*}", handler)
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, "127.0.0.1", 0)
        await site.start()
        # Port 0 asks the OS to pick a free port. Read it back through the
        # public ``AppRunner.addresses`` API instead of the private
        # ``site._server`` attribute.
        port = runner.addresses[0][1]
    except BaseException:
        await runner.cleanup()
        raise
    return runner, port
def _insert(
    model_id: str, url: str, status: str = DownloadStatus.QUEUED
) -> tuple[str, str, str]:
    """Insert a download row for ``model_id`` and return its identifiers.

    The destination and temp paths come from ``paths.resolve_destination``.
    The row id is a fresh UUID4 string.

    Returns:
        ``(download_id, final_path, temp_path)``.
    """
    dest, tmp = paths.resolve_destination(model_id)
    new_id = f"{uuid.uuid4()}"
    row = dict(
        id=new_id,
        url=url,
        model_id=model_id,
        dest_path=dest,
        temp_path=tmp,
        status=status,
    )
    queries.insert_download(row)
    return new_id, dest, tmp
# ----- single-stream -----
def test_single_stream_download(model_root):
    """A server without range support is downloaded in one stream, byte-exact.

    ``model_root`` is a fixture defined elsewhere. Presumably it provides an
    isolated models directory; TODO confirm.
    """
    payload = _payload(300_000)

    async def _run():
        # Presumably resets a shared HTTP session so it binds to this event
        # loop; TODO confirm against net.session.
        await close_session()
        runner, port = await _serve(_noranges_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp_path = _insert("loras/single.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/single.safetensors",
                dest_path=final_path, temp_path=temp_path,
            ))
            status = await job.run()
            assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
            assert os.path.exists(final_path)
            # Close the handle explicitly; a bare open().read() leaks it.
            with open(final_path, "rb") as f:
                assert f.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
# ----- segmented -----
def test_segmented_download(model_root):
    """A range-capable server is fetched in several segments, byte-exact.

    Besides checking the file contents, the test asserts that more than one
    segment row was recorded for the download.
    """
    payload = _payload(4 * 1024 * 1024)  # 4 MiB -> multiple segments

    async def _run():
        await close_session()
        runner, port = await _serve(_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/seg.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/seg.safetensors",
                dest_path=final_path, temp_path=temp,
            ))
            status = await job.run()
            assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
            with open(final_path, "rb") as f:
                assert f.read() == payload
            # More than one segment row was planned.
            assert len(queries.list_segments(did)) > 1
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
# ----- deterministic resume from a partial file -----
def test_resume_from_partial(model_root):
    """A job started over an existing partial temp file still produces exact bytes.

    Set-up before the job runs:
    - the first 200 KiB of the payload are written to the temp path;
    - ``bytes_done`` and the ETag are recorded on the row;
    - the job is given the same ETag the server returns.
    """
    payload = _payload(512 * 1024)  # < 1 MiB -> single segment, but ranges work

    async def _run():
        await close_session()
        runner, port = await _serve(_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/resume.safetensors", url)
            # Simulate a prior partial: first 200 KiB already written, offset persisted.
            prefix = 200 * 1024
            os.makedirs(os.path.dirname(temp), exist_ok=True)
            with open(temp, "wb") as f:
                f.write(payload[:prefix])
            queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/resume.safetensors",
                dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
            ))
            status = await job.run()
            assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
            with open(final_path, "rb") as f:
                assert f.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
# ----- cancel rollback -----
def test_cancel_rollback(model_root, monkeypatch):
    """Cancelling an in-flight job reports CANCELLED and leaves no files behind.

    A slow streaming server keeps the transfer running. The test cancels once
    progress is observed, then checks that neither the temp file nor the
    final file exists.
    """
    # Small chunk size, presumably so progress is reported after the first
    # 16 KiB read; TODO confirm how the engine uses it.
    monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
    payload = _payload(1024 * 1024)

    async def _run():
        await close_session()
        runner, port = await _serve(_slow_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/cancel.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/cancel.safetensors",
                dest_path=final_path, temp_path=temp,
            ))
            # create_task is the recommended way to schedule a coroutine.
            task = asyncio.create_task(job.run())
            # Poll for up to ~2 s until some bytes are written, so the cancel
            # lands mid-transfer. If no progress shows up, cancel anyway.
            for _ in range(200):
                await asyncio.sleep(0.01)
                if job.state.bytes_done > 0:
                    break
            job.request_cancel()
            status = await task
            assert status == DownloadStatus.CANCELLED
            assert not os.path.exists(temp)
            assert not os.path.exists(final_path)
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
# ----- manager + scheduler end-to-end -----
def test_manager_enqueue_to_completion(model_root):
    """End-to-end: ``DOWNLOAD_MANAGER.enqueue`` drives a download to COMPLETED.

    The test only calls ``enqueue`` and then polls the database row until it
    reaches a terminal status. Finally it verifies the file is byte-exact.
    """
    payload = _payload(2 * 1024 * 1024)

    async def _run():
        await close_session()
        # Imported inside the coroutine; presumably to defer module-import
        # side effects until a loop is running. TODO confirm.
        from app.model_downloader.manager import DOWNLOAD_MANAGER

        runner, port = await _serve(_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
            final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
            # Wait (up to ~10 s) for the row to reach a terminal status.
            for _ in range(500):
                await asyncio.sleep(0.02)
                row = queries.get_download(did)
                if row.status in DownloadStatus.TERMINAL:
                    break
            row = queries.get_download(did)
            assert row.status == DownloadStatus.COMPLETED, row.error
            with open(final_path, "rb") as f:
                assert f.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
def test_manager_rejects_disallowed_url(model_root):
    """Enqueueing an ``evil.example.com`` URL fails with ``URL_NOT_ALLOWED``."""

    async def _run():
        from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError

        bad_url = "https://evil.example.com/x.safetensors"
        with pytest.raises(DownloadError) as excinfo:
            await DOWNLOAD_MANAGER.enqueue(bad_url, "loras/bad.safetensors")
        # The error carries a machine-readable code alongside the exception type.
        assert excinfo.value.code == "URL_NOT_ALLOWED"

    asyncio.run(_run())