diff --git a/app/model_downloader/engine/job.py b/app/model_downloader/engine/job.py index ef1d5552f..ad05192b2 100644 --- a/app/model_downloader/engine/job.py +++ b/app/model_downloader/engine/job.py @@ -199,6 +199,13 @@ class DownloadJob: raise RetryableError(pr.error or "probe failed") raise FatalError(pr.error or f"probe returned HTTP {pr.status}") + max_bytes = self._max_download_bytes() + if max_bytes is not None and pr.total_bytes is not None and pr.total_bytes > max_bytes: + raise FatalError( + f"file size {pr.total_bytes} exceeds the maximum allowed " + f"download size {max_bytes} (--download-max-bytes)" + ) + self._etag = pr.etag or self._etag self.state.total_bytes = pr.total_bytes queries.update_download( @@ -318,11 +325,29 @@ class DownloadJob: self._raise_for_status(resp.status) async for chunk in resp.content.iter_chunked(args.download_chunk_size): self._check_control() + # Never write past this segment's planned range: a + # non-conforming 206 that returns more than the requested + # bytes would otherwise overrun adjacent segments and the + # preallocated file. Cap the write and abort on overflow. + remaining = seg.length - seg.bytes_done + if remaining <= 0: + raise FatalError( + f"segment {seg.idx}: server returned more than the " + f"requested {seg.length} bytes" + ) + overflow = len(chunk) > remaining + if overflow: + chunk = chunk[:remaining] await self._writer.write_at(offset, chunk) offset += len(chunk) seg.bytes_done += len(chunk) self._recompute_bytes_done() await self._persist_progress() + if overflow: + raise FatalError( + f"segment {seg.idx}: server returned more than the " + f"requested {seg.length} bytes" + ) async def _run_single(self) -> None: seg = self.state.segments[0] @@ -343,13 +368,37 @@ class DownloadJob: self._raise_for_status(resp.status) elif offset == 0 and resp.status != 200: self._raise_for_status(resp.status) + # Byte ceiling for this stream: the known total when the server + # reported a size, otherwise the configured maximum download size. + # Without a bound, a non-conforming response or an unknown-length + # stream (end == -1) that never closes could fill the disk (DoS). + limit = (seg.end + 1) if seg.end >= 0 else self._max_download_bytes() async for chunk in resp.content.iter_chunked(args.download_chunk_size): self._check_control() + overflow = False + if limit is not None: + remaining = limit - offset + if remaining <= 0: + raise FatalError( + f"download exceeded the maximum size {limit} bytes" + ) + if len(chunk) > remaining: + chunk = chunk[:remaining] + overflow = True await self._writer.write_at(offset, chunk) offset += len(chunk) seg.bytes_done = offset self.state.bytes_done = offset await self._persist_progress() + if overflow: + raise FatalError( + f"download exceeded the maximum size {limit} bytes" + ) + + def _max_download_bytes(self) -> Optional[int]: + """Configured maximum download size in bytes, or ``None`` if disabled.""" + cap = getattr(args, "download_max_bytes", 0) + return cap if cap and cap > 0 else None def _raise_for_status(self, status: int) -> None: if status in (401, 403): diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 9130d5f15..afa4f5e6a 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -248,6 +248,7 @@ parser.add_argument("--download-segments", type=int, default=8, metavar="N", hel parser.add_argument("--download-max-active", type=int, default=3, metavar="N", help="Maximum number of model downloads running concurrently (default: 3).") parser.add_argument("--download-max-connections-per-host", type=int, default=16, metavar="N", help="Maximum simultaneous connections to a single host for the download manager (default: 16).") parser.add_argument("--download-chunk-size", type=int, default=4 * 1024 * 1024, metavar="BYTES", help="Read chunk size in bytes for the download manager (default: 4 MiB).") +parser.add_argument("--download-max-bytes", type=int, default=1024 * 1024 * 1024 * 1024, metavar="BYTES", help="Maximum size in bytes of a single download; aborts transfers that exceed it (guards against malicious/non-conforming hosts filling the disk). Set to 0 to disable (default: 1 TiB).") parser.add_argument("--download-allowed-hosts", type=str, nargs="*", default=[], metavar="HOST", help="Additional hostnames to add to the download manager allowlist (https only). The built-in defaults always include huggingface.co and civitai.com.") parser.add_argument("--download-allow-any-extension", action="store_true", help="Allow the download manager to fetch files with any extension (default: only known model extensions like .safetensors).") diff --git a/tests-unit/model_downloader_test/test_engine_integration.py b/tests-unit/model_downloader_test/test_engine_integration.py index 8818d6c8a..e9977d6d0 100644 --- a/tests-unit/model_downloader_test/test_engine_integration.py +++ b/tests-unit/model_downloader_test/test_engine_integration.py @@ -76,6 +76,50 @@ def _slow_handler(payload: bytes, chunk: int = 16384, delay: float = 0.01): return handler +def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024): + """A non-conforming 206 server that returns MORE than the requested range.""" + + async def handler(request: web.Request) -> web.Response: + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + # Maliciously overrun: append extra bytes past the requested end. + body = payload[start : end + 1] + bytes(extra) + return web.Response( + status=206, + body=body, + headers={ + "Content-Range": f"bytes {start}-{end}/{len(payload)}", + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + }, + ) + return web.Response( + status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} + ) + + return handler + + +def _unbounded_handler(total: int, chunk: int = 16384): + """A 200 stream with no Content-Length / Accept-Ranges (unknown length).""" + + async def handler(request: web.Request) -> web.StreamResponse: + resp = web.StreamResponse(status=200) + await resp.prepare(request) + sent = 0 + while sent < total: + await resp.write(bytes(min(chunk, total - sent))) + sent += chunk + await resp.write_eof() + return resp + + return handler + + async def _serve(handler): app = web.Application() app.router.add_route("*", "/{name:.*}", handler) @@ -226,6 +270,85 @@ def test_cancel_rollback(model_root, monkeypatch): asyncio.run(_run()) +# ----- size-bound enforcement (malicious / non-conforming hosts) ----- + + +def test_segment_overflow_aborts(model_root): + """A 206 returning more than the requested range must not overrun.""" + payload = _payload(4 * 1024 * 1024) # large enough to segment + + async def _run(): + await close_session() + runner, port = await _serve(_overflow_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/overflow.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/overflow.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_rejects_oversized_known_download(model_root, monkeypatch): + """A file whose advertised size exceeds the cap is rejected at probe.""" + monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False) + payload = _payload(300_000) + + async def _run(): + await close_session() + runner, port = await _serve(_noranges_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/toobig.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/toobig.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED + assert not os.path.exists(final_path) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch): + """An unbounded unknown-length stream is capped by --download-max-bytes.""" + monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False) + monkeypatch.setattr(args, "download_chunk_size", 16384, raising=False) + + async def _run(): + await close_session() + runner, port = await _serve(_unbounded_handler(2 * 1024 * 1024)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/unbounded.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/unbounded.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + # ----- manager + scheduler end-to-end -----