mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Add max-download-size in case the server tries to send larger files than it reports.
This commit is contained in:
parent
95b0758a88
commit
4ae294d2d5
@ -199,6 +199,13 @@ class DownloadJob:
|
||||
raise RetryableError(pr.error or "probe failed")
|
||||
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
||||
|
||||
max_bytes = self._max_download_bytes()
|
||||
if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes:
|
||||
raise FatalError(
|
||||
f"file size {pr.total_bytes} exceeds the maximum allowed "
|
||||
f"download size {max_bytes} (--download-max-bytes)"
|
||||
)
|
||||
|
||||
self._etag = pr.etag or self._etag
|
||||
self.state.total_bytes = pr.total_bytes
|
||||
queries.update_download(
|
||||
@ -318,11 +325,29 @@ class DownloadJob:
|
||||
self._raise_for_status(resp.status)
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
# Never write past this segment's planned range: a
|
||||
# non-conforming 206 that returns more than the requested
|
||||
# bytes would otherwise overrun adjacent segments and the
|
||||
# preallocated file. Cap the write and abort on overflow.
|
||||
remaining = seg.length - seg.bytes_done
|
||||
if remaining <= 0:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx}: server returned more than the "
|
||||
f"requested {seg.length} bytes"
|
||||
)
|
||||
overflow = len(chunk) > remaining
|
||||
if overflow:
|
||||
chunk = chunk[:remaining]
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done += len(chunk)
|
||||
self._recompute_bytes_done()
|
||||
await self._persist_progress()
|
||||
if overflow:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx}: server returned more than the "
|
||||
f"requested {seg.length} bytes"
|
||||
)
|
||||
|
||||
async def _run_single(self) -> None:
|
||||
seg = self.state.segments[0]
|
||||
@ -343,13 +368,37 @@ class DownloadJob:
|
||||
self._raise_for_status(resp.status)
|
||||
elif offset == 0 and resp.status != 200:
|
||||
self._raise_for_status(resp.status)
|
||||
# Byte ceiling for this stream: the known total when the server
|
||||
# reported a size, otherwise the configured maximum download size.
|
||||
# Without a bound, a non-conforming response or an unknown-length
|
||||
# stream (end == -1) that never closes could fill the disk (DoS).
|
||||
limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes()
|
||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||
self._check_control()
|
||||
overflow = False
|
||||
if limit is not None:
|
||||
remaining = limit - offset
|
||||
if remaining <= 0:
|
||||
raise FatalError(
|
||||
f"download exceeded the maximum size {limit} bytes"
|
||||
)
|
||||
if len(chunk) > remaining:
|
||||
chunk = chunk[:remaining]
|
||||
overflow = True
|
||||
await self._writer.write_at(offset, chunk)
|
||||
offset += len(chunk)
|
||||
seg.bytes_done = offset
|
||||
self.state.bytes_done = offset
|
||||
await self._persist_progress()
|
||||
if overflow:
|
||||
raise FatalError(
|
||||
f"download exceeded the maximum size {limit} bytes"
|
||||
)
|
||||
|
||||
def _max_download_bytes(self) -> Optional[int]:
|
||||
"""Configured maximum download size in bytes, or ``None`` if disabled."""
|
||||
cap = getattr(args, "download_max_bytes", 0)
|
||||
return cap if cap and cap > 0 else None
|
||||
|
||||
def _raise_for_status(self, status: int) -> None:
|
||||
if status in (401, 403):
|
||||
|
||||
@ -248,6 +248,7 @@ parser.add_argument("--download-segments", type=int, default=8, metavar="N", hel
|
||||
parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
||||
parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
||||
parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
||||
parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
|
||||
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
||||
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
||||
|
||||
|
||||
@ -76,6 +76,50 @@ def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
|
||||
return handler
|
||||
|
||||
|
||||
def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
|
||||
"""A non-conforming 206 server that returns MORE than the requested range."""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
# Maliciously overrun: append extra bytes past the requested end.
|
||||
body = payload[start : end + 1] + bytes(extra)
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=body,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _unbounded_handler(total: int, chunk: int = 16384):
|
||||
"""A 200 stream with no Content-Length / Accept-Ranges (unknown length)."""
|
||||
|
||||
async def handler(request: web.Request) -> web.StreamResponse:
|
||||
resp = web.StreamResponse(status=200)
|
||||
await resp.prepare(request)
|
||||
sent = 0
|
||||
while sent < total:
|
||||
await resp.write(bytes(min(chunk, total - sent)))
|
||||
sent += chunk
|
||||
await resp.write_eof()
|
||||
return resp
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
async def _serve(handler):
|
||||
app = web.Application()
|
||||
app.router.add_route("*", "/{name:.*}", handler)
|
||||
@ -226,6 +270,85 @@ def test_cancel_rollback(model_root, monkeypatch):
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- size-bound enforcement (malicious / non-conforming hosts) -----
|
||||
|
||||
|
||||
def test_segment_overflow_aborts(model_root):
|
||||
"""A 206 returning more than the requested range must not overrun."""
|
||||
payload = _payload(4 * 1024 * 1024) # large enough to segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_overflow_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/overflow.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/overflow.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_rejects_oversized_known_download(model_root, monkeypatch):
|
||||
"""A file whose advertised size exceeds the cap is rejected at probe."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
payload = _payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/toobig.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/toobig.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
|
||||
"""An unbounded unknown-length stream is capped by --download-max-bytes."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/unbounded.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/unbounded.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
# ----- manager + scheduler end-to-end -----
|
||||
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user