Add --download-max-bytes (a maximum download size) in case the server tries to send larger files than it reports.

This commit is contained in:
Talmaj Marinc 2026-06-30 10:15:47 +02:00
parent 95b0758a88
commit 4ae294d2d5
3 changed files with 173 additions and 0 deletions

View File

@ -199,6 +199,13 @@ class DownloadJob:
raise RetryableError(pr.error or "probe failed")
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
max_bytes = self._max_download_bytes()
if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes:
raise FatalError(
f"file size {pr.total_bytes} exceeds the maximum allowed "
f"download size {max_bytes} (--download-max-bytes)"
)
self._etag = pr.etag or self._etag
self.state.total_bytes = pr.total_bytes
queries.update_download(
@ -318,11 +325,29 @@ class DownloadJob:
self._raise_for_status(resp.status)
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
# Never write past this segment's planned range: a
# non-conforming 206 that returns more than the requested
# bytes would otherwise overrun adjacent segments and the
# preallocated file. Cap the write and abort on overflow.
remaining = seg.length - seg.bytes_done
if remaining <= 0:
raise FatalError(
f"segment {seg.idx}: server returned more than the "
f"requested {seg.length} bytes"
)
overflow = len(chunk) > remaining
if overflow:
chunk = chunk[:remaining]
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done += len(chunk)
self._recompute_bytes_done()
await self._persist_progress()
if overflow:
raise FatalError(
f"segment {seg.idx}: server returned more than the "
f"requested {seg.length} bytes"
)
async def _run_single(self) -> None:
seg = self.state.segments[0]
@ -343,13 +368,37 @@ class DownloadJob:
self._raise_for_status(resp.status)
elif offset == 0 and resp.status != 200:
self._raise_for_status(resp.status)
# Byte ceiling for this stream: the known total when the server
# reported a size, otherwise the configured maximum download size.
# Without a bound, a non-conforming response or an unknown-length
# stream (end == -1) that never closes could fill the disk (DoS).
limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes()
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
self._check_control()
overflow = False
if limit is not None:
remaining = limit - offset
if remaining <= 0:
raise FatalError(
f"download exceeded the maximum size {limit} bytes"
)
if len(chunk) > remaining:
chunk = chunk[:remaining]
overflow = True
await self._writer.write_at(offset, chunk)
offset += len(chunk)
seg.bytes_done = offset
self.state.bytes_done = offset
await self._persist_progress()
if overflow:
raise FatalError(
f"download exceeded the maximum size {limit} bytes"
)
def _max_download_bytes(self) -> Optional[int]:
    """Return the ``--download-max-bytes`` ceiling, or ``None`` when no cap applies.

    The value is read from ``args.download_max_bytes``. A missing attribute,
    ``None``, ``0`` and negative numbers all mean "no limit".
    """
    configured = getattr(args, "download_max_bytes", 0)
    if not configured:
        # Falsy (0, None, ...): bail out before attempting a numeric comparison.
        return None
    if configured > 0:
        return configured
    return None
def _raise_for_status(self, status: int) -> None:
if status in (401, 403):

View File

@ -248,6 +248,7 @@ parser.add_argument("--download-segments", type=int, default=8, metavar="N", hel
# Download-manager tuning options (concurrency, connection and chunk sizing).
parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
# Per-download byte ceiling (default 1024**4 bytes = 1 TiB). The download job
# treats every non-positive value as "no limit", not only the documented 0.
parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
# Download-manager allowlists: extra hosts and the file-extension escape hatch.
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")

View File

@ -76,6 +76,50 @@ def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
return handler
def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
    """Build a handler that answers Range requests with an over-long 206 body.

    The ``Content-Range`` header describes exactly the requested span, but the
    body carries ``extra`` additional zero bytes after it, imitating a
    non-conforming (or malicious) server. Requests without a ``Range`` header
    get the whole ``payload`` back as a normal 200 response.
    """

    async def handler(request: web.Request) -> web.Response:
        range_header = request.headers.get("Range")
        if not range_header:
            # No range requested: serve the full payload and advertise ranges.
            return web.Response(
                status=200,
                body=payload,
                headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG},
            )

        # "bytes=<first>-<last>"; an open-ended range runs to the last byte.
        first_text, _, last_text = range_header.split("=", 1)[1].partition("-")
        first = int(first_text)
        last = int(last_text) if last_text else len(payload) - 1

        # The overrun itself: requested slice plus padding past its end.
        requested = payload[first : last + 1]
        oversized = requested + bytes(extra)
        return web.Response(
            status=206,
            body=oversized,
            headers={
                "Content-Range": f"bytes {first}-{last}/{len(payload)}",
                "Accept-Ranges": "bytes",
                "ETag": PAYLOAD_ETAG,
            },
        )

    return handler
def _unbounded_handler(total: int, chunk: int = 16384):
    """Build a handler that streams ``total`` zero bytes as a 200 response.

    The response is a ``web.StreamResponse`` with no Content-Length or
    Accept-Ranges header set, so the client cannot know the length up front.
    Data is written in pieces of at most ``chunk`` bytes.
    """

    async def handler(request: web.Request) -> web.StreamResponse:
        stream = web.StreamResponse(status=200)
        await stream.prepare(request)
        remaining = total
        while remaining > 0:
            # The final piece may be shorter than a full chunk.
            size = min(chunk, remaining)
            await stream.write(bytes(size))
            remaining -= size
        await stream.write_eof()
        return stream

    return handler
async def _serve(handler):
app = web.Application()
app.router.add_route("*", "/{name:.*}", handler)
@ -226,6 +270,85 @@ def test_cancel_rollback(model_root, monkeypatch):
asyncio.run(_run())
# ----- size-bound enforcement (malicious / non-conforming hosts) -----
def test_segment_overflow_aborts(model_root):
    """An over-long 206 body must fail the job and leave no files behind."""
    model_id = "loras/overflow.safetensors"
    # 4 MiB: large enough for the download to be split into segments.
    blob = _payload(4 * 1024 * 1024)

    async def _scenario():
        await close_session()
        server, port = await _serve(_overflow_range_handler(blob))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            download_id, dest, partial = _insert(model_id, url)
            spec = JobSpec(
                download_id=download_id,
                url=url,
                model_id=model_id,
                dest_path=dest,
                temp_path=partial,
            )
            outcome = await DownloadJob(spec).run()
            assert outcome == DownloadStatus.FAILED
            # Neither the final file nor the partial temp file may survive.
            assert not os.path.exists(dest)
            assert not os.path.exists(partial)
        finally:
            # Always tear down the test server and the shared client session.
            await server.cleanup()
            await close_session()

    asyncio.run(_scenario())
def test_rejects_oversized_known_download(model_root, monkeypatch):
    """A file whose advertised size exceeds the cap is rejected at probe.

    The cap is lowered to 100_000 bytes while the server offers a 300_000-byte
    payload, so the job is expected to end in ``FAILED`` without leaving a
    final file or a partial temp file on disk.
    """
    monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
    payload = _payload(300_000)

    async def _run():
        await close_session()
        runner, port = await _serve(_noranges_handler(payload))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            did, final_path, temp = _insert("loras/toobig.safetensors", url)
            job = DownloadJob(JobSpec(
                download_id=did, url=url, model_id="loras/toobig.safetensors",
                dest_path=final_path, temp_path=temp,
            ))
            status = await job.run()
            assert status == DownloadStatus.FAILED
            assert not os.path.exists(final_path)
            # Same clean-up contract the other size-bound tests check: a
            # rejected download must not leave a partial temp file behind.
            assert not os.path.exists(temp)
        finally:
            # Always tear down the test server and the shared client session.
            await runner.cleanup()
            await close_session()

    asyncio.run(_run())
def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
    """A length-less stream larger than ``--download-max-bytes`` must fail.

    The server streams 2 MiB without announcing a size while the cap is set to
    100_000 bytes (read in 16384-byte chunks); the job has to end in ``FAILED``
    with neither the final nor the temp file left on disk.
    """
    overrides = (
        ("download_max_bytes", 100_000),
        ("download_chunk_size", 16384),
    )
    for option, value in overrides:
        monkeypatch.setattr(args, option, value, raising=False)

    model_id = "loras/unbounded.safetensors"

    async def _scenario():
        await close_session()
        server, port = await _serve(_unbounded_handler(2 * 1024 * 1024))
        try:
            url = f"http://127.0.0.1:{port}/model.safetensors"
            download_id, dest, partial = _insert(model_id, url)
            spec = JobSpec(
                download_id=download_id,
                url=url,
                model_id=model_id,
                dest_path=dest,
                temp_path=partial,
            )
            outcome = await DownloadJob(spec).run()
            assert outcome == DownloadStatus.FAILED
            assert not os.path.exists(dest)
            assert not os.path.exists(partial)
        finally:
            # Always tear down the test server and the shared client session.
            await server.cleanup()
            await close_session()

    asyncio.run(_scenario())
# ----- manager + scheduler end-to-end -----