mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Add max-download-size in case the server tries to send larger files than it reports.
This commit is contained in:
parent
95b0758a88
commit
4ae294d2d5
@ -199,6 +199,13 @@ class DownloadJob:
|
|||||||
raise RetryableError(pr.error or "probe failed")
|
raise RetryableError(pr.error or "probe failed")
|
||||||
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
raise FatalError(pr.error or f"probe returned HTTP {pr.status}")
|
||||||
|
|
||||||
|
max_bytes = self._max_download_bytes()
|
||||||
|
if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes:
|
||||||
|
raise FatalError(
|
||||||
|
f"file size {pr.total_bytes} exceeds the maximum allowed "
|
||||||
|
f"download size {max_bytes} (--download-max-bytes)"
|
||||||
|
)
|
||||||
|
|
||||||
self._etag = pr.etag or self._etag
|
self._etag = pr.etag or self._etag
|
||||||
self.state.total_bytes = pr.total_bytes
|
self.state.total_bytes = pr.total_bytes
|
||||||
queries.update_download(
|
queries.update_download(
|
||||||
@ -318,11 +325,29 @@ class DownloadJob:
|
|||||||
self._raise_for_status(resp.status)
|
self._raise_for_status(resp.status)
|
||||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||||
self._check_control()
|
self._check_control()
|
||||||
|
# Never write past this segment's planned range: a
|
||||||
|
# non-conforming 206 that returns more than the requested
|
||||||
|
# bytes would otherwise overrun adjacent segments and the
|
||||||
|
# preallocated file. Cap the write and abort on overflow.
|
||||||
|
remaining = seg.length - seg.bytes_done
|
||||||
|
if remaining <= 0:
|
||||||
|
raise FatalError(
|
||||||
|
f"segment {seg.idx}: server returned more than the "
|
||||||
|
f"requested {seg.length} bytes"
|
||||||
|
)
|
||||||
|
overflow = len(chunk) > remaining
|
||||||
|
if overflow:
|
||||||
|
chunk = chunk[:remaining]
|
||||||
await self._writer.write_at(offset, chunk)
|
await self._writer.write_at(offset, chunk)
|
||||||
offset += len(chunk)
|
offset += len(chunk)
|
||||||
seg.bytes_done += len(chunk)
|
seg.bytes_done += len(chunk)
|
||||||
self._recompute_bytes_done()
|
self._recompute_bytes_done()
|
||||||
await self._persist_progress()
|
await self._persist_progress()
|
||||||
|
if overflow:
|
||||||
|
raise FatalError(
|
||||||
|
f"segment {seg.idx}: server returned more than the "
|
||||||
|
f"requested {seg.length} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
async def _run_single(self) -> None:
|
async def _run_single(self) -> None:
|
||||||
seg = self.state.segments[0]
|
seg = self.state.segments[0]
|
||||||
@ -343,13 +368,37 @@ class DownloadJob:
|
|||||||
self._raise_for_status(resp.status)
|
self._raise_for_status(resp.status)
|
||||||
elif offset == 0 and resp.status != 200:
|
elif offset == 0 and resp.status != 200:
|
||||||
self._raise_for_status(resp.status)
|
self._raise_for_status(resp.status)
|
||||||
|
# Byte ceiling for this stream: the known total when the server
|
||||||
|
# reported a size, otherwise the configured maximum download size.
|
||||||
|
# Without a bound, a non-conforming response or an unknown-length
|
||||||
|
# stream (end == -1) that never closes could fill the disk (DoS).
|
||||||
|
limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes()
|
||||||
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
async for chunk in resp.content.iter_chunked(args.download_chunk_size):
|
||||||
self._check_control()
|
self._check_control()
|
||||||
|
overflow = False
|
||||||
|
if limit is not None:
|
||||||
|
remaining = limit - offset
|
||||||
|
if remaining <= 0:
|
||||||
|
raise FatalError(
|
||||||
|
f"download exceeded the maximum size {limit} bytes"
|
||||||
|
)
|
||||||
|
if len(chunk) > remaining:
|
||||||
|
chunk = chunk[:remaining]
|
||||||
|
overflow = True
|
||||||
await self._writer.write_at(offset, chunk)
|
await self._writer.write_at(offset, chunk)
|
||||||
offset += len(chunk)
|
offset += len(chunk)
|
||||||
seg.bytes_done = offset
|
seg.bytes_done = offset
|
||||||
self.state.bytes_done = offset
|
self.state.bytes_done = offset
|
||||||
await self._persist_progress()
|
await self._persist_progress()
|
||||||
|
if overflow:
|
||||||
|
raise FatalError(
|
||||||
|
f"download exceeded the maximum size {limit} bytes"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _max_download_bytes(self) -> Optional[int]:
|
||||||
|
"""Configured maximum download size in bytes, or ``None`` if disabled."""
|
||||||
|
cap = getattr(args, "download_max_bytes", 0)
|
||||||
|
return cap if cap and cap > 0 else None
|
||||||
|
|
||||||
def _raise_for_status(self, status: int) -> None:
|
def _raise_for_status(self, status: int) -> None:
|
||||||
if status in (401, 403):
|
if status in (401, 403):
|
||||||
|
|||||||
@ -248,6 +248,7 @@ parser.add_argument("--download-segments", type=int, default=8, metavar="N", hel
|
|||||||
parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).")
|
||||||
parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).")
|
||||||
parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).")
|
||||||
|
parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).")
|
||||||
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.")
|
||||||
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).")
|
||||||
|
|
||||||
|
|||||||
@ -76,6 +76,50 @@ def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01):
|
|||||||
return handler
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
|
||||||
|
"""A non-conforming 206 server that returns MORE than the requested range."""
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.Response:
|
||||||
|
rng = request.headers.get("Range")
|
||||||
|
if rng:
|
||||||
|
spec = rng.split("=", 1)[1]
|
||||||
|
s, _, e = spec.partition("-")
|
||||||
|
start = int(s)
|
||||||
|
end = int(e) if e else len(payload) - 1
|
||||||
|
# Maliciously overrun: append extra bytes past the requested end.
|
||||||
|
body = payload[start : end + 1] + bytes(extra)
|
||||||
|
return web.Response(
|
||||||
|
status=206,
|
||||||
|
body=body,
|
||||||
|
headers={
|
||||||
|
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||||
|
"Accept-Ranges": "bytes",
|
||||||
|
"ETag": PAYLOAD_ETAG,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
return web.Response(
|
||||||
|
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||||
|
)
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
|
def _unbounded_handler(total: int, chunk: int = 16384):
|
||||||
|
"""A 200 stream with no Content-Length / Accept-Ranges (unknown length)."""
|
||||||
|
|
||||||
|
async def handler(request: web.Request) -> web.StreamResponse:
|
||||||
|
resp = web.StreamResponse(status=200)
|
||||||
|
await resp.prepare(request)
|
||||||
|
sent = 0
|
||||||
|
while sent < total:
|
||||||
|
await resp.write(bytes(min(chunk, total - sent)))
|
||||||
|
sent += chunk
|
||||||
|
await resp.write_eof()
|
||||||
|
return resp
|
||||||
|
|
||||||
|
return handler
|
||||||
|
|
||||||
|
|
||||||
async def _serve(handler):
|
async def _serve(handler):
|
||||||
app = web.Application()
|
app = web.Application()
|
||||||
app.router.add_route("*", "/{name:.*}", handler)
|
app.router.add_route("*", "/{name:.*}", handler)
|
||||||
@ -226,6 +270,85 @@ def test_cancel_rollback(model_root, monkeypatch):
|
|||||||
asyncio.run(_run())
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
# ----- size-bound enforcement (malicious / non-conforming hosts) -----
|
||||||
|
|
||||||
|
|
||||||
|
def test_segment_overflow_aborts(model_root):
|
||||||
|
"""A 206 returning more than the requested range must not overrun."""
|
||||||
|
payload = _payload(4 * 1024 * 1024) # large enough to segment
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
await close_session()
|
||||||
|
runner, port = await _serve(_overflow_range_handler(payload))
|
||||||
|
try:
|
||||||
|
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||||
|
did, final_path, temp = _insert("loras/overflow.safetensors", url)
|
||||||
|
job = DownloadJob(JobSpec(
|
||||||
|
download_id=did, url=url, model_id="loras/overflow.safetensors",
|
||||||
|
dest_path=final_path, temp_path=temp,
|
||||||
|
))
|
||||||
|
status = await job.run()
|
||||||
|
assert status == DownloadStatus.FAILED
|
||||||
|
assert not os.path.exists(final_path)
|
||||||
|
assert not os.path.exists(temp)
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
await close_session()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_rejects_oversized_known_download(model_root, monkeypatch):
|
||||||
|
"""A file whose advertised size exceeds the cap is rejected at probe."""
|
||||||
|
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||||
|
payload = _payload(300_000)
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
await close_session()
|
||||||
|
runner, port = await _serve(_noranges_handler(payload))
|
||||||
|
try:
|
||||||
|
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||||
|
did, final_path, temp = _insert("loras/toobig.safetensors", url)
|
||||||
|
job = DownloadJob(JobSpec(
|
||||||
|
download_id=did, url=url, model_id="loras/toobig.safetensors",
|
||||||
|
dest_path=final_path, temp_path=temp,
|
||||||
|
))
|
||||||
|
status = await job.run()
|
||||||
|
assert status == DownloadStatus.FAILED
|
||||||
|
assert not os.path.exists(final_path)
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
await close_session()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
|
def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
|
||||||
|
"""An unbounded unknown-length stream is capped by --download-max-bytes."""
|
||||||
|
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||||
|
monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False)
|
||||||
|
|
||||||
|
async def _run():
|
||||||
|
await close_session()
|
||||||
|
runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024))
|
||||||
|
try:
|
||||||
|
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||||
|
did, final_path, temp = _insert("loras/unbounded.safetensors", url)
|
||||||
|
job = DownloadJob(JobSpec(
|
||||||
|
download_id=did, url=url, model_id="loras/unbounded.safetensors",
|
||||||
|
dest_path=final_path, temp_path=temp,
|
||||||
|
))
|
||||||
|
status = await job.run()
|
||||||
|
assert status == DownloadStatus.FAILED
|
||||||
|
assert not os.path.exists(final_path)
|
||||||
|
assert not os.path.exists(temp)
|
||||||
|
finally:
|
||||||
|
await runner.cleanup()
|
||||||
|
await close_session()
|
||||||
|
|
||||||
|
asyncio.run(_run())
|
||||||
|
|
||||||
|
|
||||||
# ----- manager + scheduler end-to-end -----
|
# ----- manager + scheduler end-to-end -----
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user