mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
271 lines
8.9 KiB
Python
"""Integration tests for the download engine against a local aiohttp server.
|
|
|
|
Covers single-stream and segmented transfers, deterministic resume from a
|
|
partial file, and cancel rollback. Async tests are driven via ``asyncio.run``
|
|
so no pytest-asyncio plugin is required.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import asyncio
|
|
import os
|
|
import uuid
|
|
|
|
import pytest
|
|
from aiohttp import web
|
|
|
|
from comfy.cli_args import args
|
|
from app.model_downloader.constants import DownloadStatus
|
|
from app.model_downloader.database import queries
|
|
from app.model_downloader.engine.job import DownloadJob, JobSpec
|
|
from app.model_downloader.net.session import close_session
|
|
from app.model_downloader.security import paths
|
|
|
|
# Strong (non-weak) entity tag sent by the range-capable fake server. The
# resume test also stores it on the download row and in the JobSpec,
# presumably so the engine sees the remote resource as unchanged.
PAYLOAD_ETAG = "\"v1\""
|
|
|
|
|
|
def _payload(n: int) -> bytes:
    """Return ``n`` deterministic bytes used as the served model content.

    Byte ``i`` equals ``(37 * i + 11) % 256``. That value depends only on
    ``i % 256``, so a single 256-byte period is built once, tiled and trimmed
    to length ``n``. A non-positive ``n`` yields ``b""``.
    """
    period = bytes((37 * k + 11) % 256 for k in range(256))
    return (period * (n // 256 + 1))[:n]
|
|
|
|
|
|
def _range_handler(payload: bytes):
    """Return an aiohttp handler that serves ``payload`` with byte-range support.

    Behaviour by request:

    * No ``Range`` header: the whole body is returned with status 200.
    * A single ``bytes=start-end`` range: status 206 with a ``Content-Range``
      that matches the bytes actually sent. An end offset past the payload is
      clamped to the last byte, and an omitted end means "to EOF".
    * A suffix range ``bytes=-N``: the last ``N`` bytes are sent with status 206.
    * A range starting at or beyond the end of the payload, or an inverted
      range: status 416 with ``Content-Range: bytes */<size>``.

    Responses 200 and 206 advertise ``Accept-Ranges: bytes`` and
    ``PAYLOAD_ETAG``, which marks the resource as range-capable.
    """
    size = len(payload)
    full_headers = {"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}

    async def handler(request: web.Request) -> web.Response:
        rng = request.headers.get("Range")
        if not rng:
            return web.Response(status=200, body=payload, headers=dict(full_headers))

        spec = rng.split("=", 1)[1]
        s, _, e = spec.partition("-")
        if s:
            start = int(s)
            # Clamp like a real server so Content-Range never claims more
            # bytes than the body carries.
            end = min(int(e), size - 1) if e else size - 1
        else:
            # Suffix range "bytes=-N": the final N bytes of the payload.
            start = max(size - int(e), 0)
            end = size - 1

        if start >= size or start > end:
            return web.Response(
                status=416,
                headers={"Content-Range": f"bytes */{size}", "Accept-Ranges": "bytes"},
            )

        return web.Response(
            status=206,
            body=payload[start : end + 1],
            headers={
                "Content-Range": f"bytes {start}-{end}/{size}",
                **full_headers,
            },
        )

    return handler
|
|
|
|
|
|
def _noranges_handler(payload: bytes):
    """Return a handler that always answers 200 with the complete ``payload``.

    Any ``Range`` header is ignored and ``Accept-Ranges`` is never sent, so
    the download engine has to fall back to a single stream.
    """

    async def _full_body(request: web.Request) -> web.Response:
        # The request is not inspected: the full body is sent every time.
        return web.Response(body=payload, status=200)

    return _full_body
|
|
|
|
|
|
def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
    """Return a handler that trickles ``payload`` out in ``chunk``-byte pieces.

    The total size is announced up front through ``Content-Length``. Each
    piece is followed by ``asyncio.sleep(delay)``, which leaves time for a test
    to cancel the transfer part-way through. No ``Accept-Ranges`` header is
    sent.
    """
    total = len(payload)

    async def handler(request: web.Request) -> web.StreamResponse:
        stream = web.StreamResponse(status=200, headers={"Content-Length": str(total)})
        await stream.prepare(request)
        for offset in range(0, total, chunk):
            piece = payload[offset : offset + chunk]
            await stream.write(piece)
            await asyncio.sleep(delay)
        await stream.write_eof()
        return stream

    return handler
|
|
|
|
|
|
async def _serve(handler):
    """Start a throwaway aiohttp server on an OS-assigned localhost port.

    Every HTTP method on every path is routed to ``handler``.

    Returns:
        ``(runner, port)``. The caller owns ``runner`` and must
        ``await runner.cleanup()`` when done.

    If starting the site fails after the runner was set up, the runner is
    cleaned up before the exception propagates, so nothing is leaked.
    """
    app = web.Application()
    app.router.add_route("*", "/{name:.*}", handler)
    runner = web.AppRunner(app)
    await runner.setup()
    try:
        site = web.TCPSite(runner, "127.0.0.1", 0)
        await site.start()
        # Port 0 lets the OS choose a free port. Read the bound port back
        # through the public runner.addresses list (getsockname() tuples)
        # rather than the private site._server attribute.
        port = runner.addresses[0][1]
    except BaseException:
        await runner.cleanup()
        raise
    return runner, port
|
|
|
|
|
|
def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tuple[str, str, str]:
    """Insert a download row for ``model_id`` and return its identifiers.

    The destination and temporary paths come from
    ``paths.resolve_destination(model_id)``, and the row id is a fresh UUID4
    string.

    Returns:
        ``(download_id, final_path, temp_path)``.
    """
    final_path, temp_path = paths.resolve_destination(model_id)
    download_id = str(uuid.uuid4())
    row = dict(
        id=download_id,
        url=url,
        model_id=model_id,
        dest_path=final_path,
        temp_path=temp_path,
        status=status,
    )
    queries.insert_download(row)
    return download_id, final_path, temp_path
|
|
|
|
|
|
# ----- single-stream -----
|
|
|
|
|
|
def test_single_stream_download(model_root):
    """A server without range support is fetched in one stream, byte-exact.

    ``model_root`` is a fixture defined elsewhere (presumably it points the
    model directory at a temporary location).
    """
    payload = _payload(300_000)

    async def _run():
        # Drop any client session left over from an earlier test. Presumably
        # it is bound to a previous asyncio.run() event loop.
        await close_session()
        runner, port = await _serve(_noranges_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp_path = _insert("loras/single.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/single.safetensors",
                dest_path=final_path, temp_path=temp_path,
            ))
            status = await job.run()
            assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
            assert os.path.exists(final_path)
            # Close the handle deterministically instead of leaking it.
            with open(final_path, "rb") as fh:
                assert fh.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
|
|
|
|
|
# ----- segmented -----
|
|
|
|
|
|
def test_segmented_download(model_root):
    """A range-capable server is fetched in several segments, byte-exact."""
    payload = _payload(4 * 1024 * 1024)  # 4 MiB -> multiple segments

    async def _run():
        # Drop any client session left over from an earlier test. Presumably
        # it is bound to a previous asyncio.run() event loop.
        await close_session()
        runner, port = await _serve(_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/seg.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/seg.safetensors",
                dest_path=final_path, temp_path=temp,
            ))
            status = await job.run()
            assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
            # Close the handle deterministically instead of leaking it.
            with open(final_path, "rb") as fh:
                assert fh.read() == payload
            # More than one segment row was planned.
            assert len(queries.list_segments(did)) > 1
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
|
|
|
|
|
# ----- deterministic resume from a partial file -----
|
|
|
|
|
|
def test_resume_from_partial(model_root):
    """A job seeded with a partial temp file completes to the exact payload.

    The first 200 KiB are pre-written to the temp path, and the row's
    ``bytes_done`` and ``etag`` are set before the job runs. The job is then
    expected to finish the file.
    """
    payload = _payload(512 * 1024)  # < 1 MiB -> single segment, but ranges work

    async def _run():
        # Drop any client session left over from an earlier test. Presumably
        # it is bound to a previous asyncio.run() event loop.
        await close_session()
        runner, port = await _serve(_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/resume.safetensors", url)
            # Simulate a prior partial: first 200 KiB already written, offset persisted.
            prefix = 200 * 1024
            os.makedirs(os.path.dirname(temp), exist_ok=True)
            with open(temp, "wb") as f:
                f.write(payload[:prefix])
            queries.update_download(did, bytes_done=prefix, etag=PAYLOAD_ETAG)

            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/resume.safetensors",
                dest_path=final_path, temp_path=temp, etag=PAYLOAD_ETAG,
            ))
            status = await job.run()
            assert status == DownloadStatus.COMPLETED, queries.get_download(did).error
            # Close the handle deterministically instead of leaking it.
            with open(final_path, "rb") as fh:
                assert fh.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
|
|
|
|
|
# ----- cancel rollback -----
|
|
|
|
|
|
def test_cancel_rollback(model_root, monkeypatch):
    """Cancelling a running job yields CANCELLED and leaves no temp or final file."""
    # Small chunk size, presumably so progress shows up long before the
    # slow server finishes sending.
    monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
    payload = _payload(1024 * 1024)

    async def _run():
        await close_session()
        runner, port = await _serve(_slow_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            model_id = "loras/cancel.safetensors"
            did, final_path, temp = _insert(model_id, url)
            spec = JobSpec(
                download_id=did,
                url=url,
                model_id=model_id,
                dest_path=final_path,
                temp_path=temp,
            )
            job = DownloadJob(spec)
            task = asyncio.create_task(job.run())

            # Poll for up to ~2 s (200 x 10 ms) until the job reports written
            # bytes. Cancel regardless once the budget is spent.
            attempts = 0
            while attempts < 200:
                await asyncio.sleep(0.01)
                attempts += 1
                if job.state.bytes_done > 0:
                    break

            job.request_cancel()
            status = await task
            assert status == DownloadStatus.CANCELLED
            for leftover in (temp, final_path):
                assert not os.path.exists(leftover)
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
|
|
|
|
|
# ----- manager + scheduler end-to-end -----
|
|
|
|
|
|
def test_manager_enqueue_to_completion(model_root):
    """``DOWNLOAD_MANAGER.enqueue`` drives a job to COMPLETED with exact bytes.

    The DB row is polled every 20 ms for up to ~10 s until its status is in
    ``DownloadStatus.TERMINAL``.
    """
    payload = _payload(2 * 1024 * 1024)

    async def _run():
        await close_session()
        # Imported inside the test, as before. Presumably this defers creating
        # the manager singleton until the fixtures have configured the
        # environment.
        from app.model_downloader.manager import DOWNLOAD_MANAGER

        runner, port = await _serve(_range_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did = await DOWNLOAD_MANAGER.enqueue(url, "loras/e2e.safetensors")
            # Wait for completion.
            final_path, _ = paths.resolve_destination("loras/e2e.safetensors")
            for _ in range(500):
                await asyncio.sleep(0.02)
                row = queries.get_download(did)
                if row.status in DownloadStatus.TERMINAL:
                    break
            row = queries.get_download(did)
            # Report the status too, so a poll timeout (non-terminal status,
            # no error) is distinguishable from a failed download.
            assert row.status == DownloadStatus.COMPLETED, (
                f"status={row.status!r} error={row.error!r}"
            )
            # Close the handle deterministically instead of leaking it.
            with open(final_path, "rb") as fh:
                assert fh.read() == payload
        finally:
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
|
|
|
|
|
|
def test_manager_rejects_disallowed_url(model_root):
    """``enqueue`` raises ``DownloadError`` with code ``URL_NOT_ALLOWED``.

    The rejected URL is on host ``evil.example.com``.
    """

    async def _run():
        from app.model_downloader.manager import DOWNLOAD_MANAGER, DownloadError

        bad_url = "https://evil.example.com/x.safetensors"
        with pytest.raises(DownloadError) as excinfo:
            await DOWNLOAD_MANAGER.enqueue(bad_url, "loras/bad.safetensors")
        assert excinfo.value.code == "URL_NOT_ALLOWED"

    asyncio.run(_run())
|