diff --git a/app/model_downloader/engine/job.py b/app/model_downloader/engine/job.py index 942558678..499e41c81 100644 --- a/app/model_downloader/engine/job.py +++ b/app/model_downloader/engine/job.py @@ -422,24 +422,54 @@ class DownloadJob: self._set_status(DownloadStatus.VERIFYING) total = self.state.total_bytes - actual_size = os.path.getsize(self.spec.temp_path) - if total is not None and actual_size != total: + segmented = len(self.state.segments) > 1 + if segmented: + # The .part was preallocated to total_bytes, so its on-disk size is + # not evidence of completeness: a segment that ends short (truncated + # 206 / server closes mid-range) leaves a zero-filled hole while the + # file size still equals total. Verify each segment wrote its full + # planned range, and trust the byte counter (== sum of segments) + # rather than os.path.getsize for the total check. + for seg in self.state.segments: + if seg.bytes_done != seg.length: + raise FatalError( + f"segment {seg.idx} incomplete: wrote {seg.bytes_done} " + f"of {seg.length} bytes" + ) + observed = self.state.bytes_done + else: + # Single-stream writes a contiguous prefix, so the on-disk size is + # an independent witness of how much actually landed. + observed = os.path.getsize(self.spec.temp_path) + if total is not None and observed != total: raise FatalError( - f"size mismatch: wrote {actual_size} of {total} bytes" + f"size mismatch: wrote {observed} of {total} bytes" ) # Structural gate (cheap, no full read) then optional sha256 (full read). - await asyncio.to_thread(structural.validate, self.spec.temp_path) - if self.spec.expected_sha256: + # Both failures are non-retryable (a truncated/corrupt or mismatched file + # will not heal on retry), so surface them as FatalError rather than + # letting the plain Exceptions fall through to the retryable handler. + # ``temp_path`` carries the ``.part`` suffix; pass ``dest_path`` so the + # structural check detects the real file format instead of skipping it. + try: await asyncio.to_thread( - checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256 + structural.validate, self.spec.temp_path, self.spec.dest_path ) + if self.spec.expected_sha256: + await asyncio.to_thread( + checksum.verify_sha256, + self.spec.temp_path, + self.spec.expected_sha256, + ) + except (structural.StructuralError, checksum.ChecksumError) as e: + raise FatalError(str(e)) from e os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True) os.replace(self.spec.temp_path, self.spec.dest_path) logging.info( "[model_downloader] completed %s (%d bytes)", - self.spec.model_id, actual_size, + self.spec.model_id, observed, ) # Catalog into the assets system (blake3 dedup identity). Best-effort. await dedup.register_completed(self.spec.dest_path) diff --git a/app/model_downloader/verify/structural.py b/app/model_downloader/verify/structural.py index 5575ff145..7b403a31a 100644 --- a/app/model_downloader/verify/structural.py +++ b/app/model_downloader/verify/structural.py @@ -12,6 +12,7 @@ from __future__ import annotations import json import os import struct +from typing import Optional _SAFETENSORS_EXTS = (".safetensors", ".sft") # A sane upper bound so a corrupt header length can't make us read gigabytes. @@ -22,9 +23,15 @@ class StructuralError(Exception): """The file failed its structural integrity check.""" -def validate(path: str) -> None: - """Validate the file at ``path``. Raises :class:`StructuralError` on failure.""" - lower = path.lower() +def validate(path: str, name_hint: Optional[str] = None) -> None: + """Validate the file at ``path``. Raises :class:`StructuralError` on failure. + + The file format is detected from ``name_hint`` when provided, otherwise from + ``path``. Callers that download into a temp file with an opaque suffix (e.g. + ``*.comfy-download.part``) must pass the final destination name as + ``name_hint`` so the format check is not silently skipped. + """ + lower = (name_hint or path).lower() if lower.endswith(_SAFETENSORS_EXTS): _validate_safetensors(path) # No structural check for other formats; the size + (optional) checksum diff --git a/tests-unit/model_downloader_test/test_engine_integration.py b/tests-unit/model_downloader_test/test_engine_integration.py index e9977d6d0..4a4982e18 100644 --- a/tests-unit/model_downloader_test/test_engine_integration.py +++ b/tests-unit/model_downloader_test/test_engine_integration.py @@ -8,7 +8,9 @@ so no pytest-asyncio plugin is required. from __future__ import annotations import asyncio +import json import os +import struct import uuid import pytest @@ -28,6 +30,34 @@ def _payload(n: int) -> bytes: return bytes((i * 37 + 11) % 256 for i in range(n)) +def _safetensors_payload(total: int) -> bytes: + """A structurally valid ``.safetensors`` blob of exactly ``total`` bytes. + + Success-path tests download to ``.safetensors`` destinations, which the + engine now structurally validates before the atomic rename, so their + payloads must parse as real safetensors (header length + JSON header + + data region whose size matches the declared ``data_offsets``). + """ + def _header(data_len: int) -> bytes: + return json.dumps( + {"w": {"dtype": "U8", "shape": [data_len], "data_offsets": [0, data_len]}} + ).encode("utf-8") + + # The header's byte length depends on the digit count of ``data_len``, so + # iterate until ``total == 8 + len(header) + data_len`` is self-consistent. + data_len = total - 8 - len(_header(total)) + for _ in range(8): + header = _header(data_len) + new_data_len = total - 8 - len(header) + if new_data_len == data_len: + break + data_len = new_data_len + assert data_len >= 0, "total too small for a safetensors payload" + header = _header(data_len) + body = bytes((i * 37 + 11) % 256 for i in range(data_len)) + return struct.pack(" web.Response: rng = request.headers.get("Range") @@ -104,6 +134,41 @@ def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024): return handler +def _short_range_handler(payload: bytes, drop: int = 64 * 1024): + """A 206 server that returns fewer bytes than requested for later segments. + + Simulates a server cleanly closing a range connection early. The response + is internally consistent (Content-Length matches the short body), so the + client sees no error and the segment just ends short, leaving a zero-filled + hole in the preallocated file. + """ + + async def handler(request: web.Request) -> web.Response: + rng = request.headers.get("Range") + if rng: + spec = rng.split("=", 1)[1] + s, _, e = spec.partition("-") + start = int(s) + end = int(e) if e else len(payload) - 1 + chunk = payload[start : end + 1] + if start > 0 and len(chunk) > drop: + chunk = chunk[:-drop] # truncate a non-first segment + return web.Response( + status=206, + body=chunk, + headers={ + "Content-Range": f"bytes {start}-{end}/{len(payload)}", + "Accept-Ranges": "bytes", + "ETag": PAYLOAD_ETAG, + }, + ) + return web.Response( + status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG} + ) + + return handler + + def _unbounded_handler(total: int, chunk: int = 16384): """A 200 stream with no Content-Length / Accept-Ranges (unknown length).""" @@ -151,7 +216,7 @@ def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tup def test_single_stream_download(model_root): - payload = _payload(300_000) + payload = _safetensors_payload(300_000) async def _run(): await close_session() @@ -178,7 +243,7 @@ def test_single_stream_download(model_root): def test_segmented_download(model_root): - payload = _payload(4 * 1024 * 1024) # 4 MiB -> multiple segments + payload = _safetensors_payload(4 * 1024 * 1024) # 4 MiB -> multiple segments async def _run(): await close_session() @@ -206,7 +271,7 @@ def test_segmented_download(model_root): def test_resume_from_partial(model_root): - payload = _payload(512 * 1024) # < 1 MiB -> single segment, but ranges work + payload = _safetensors_payload(512 * 1024) # < 1 MiB -> single segment async def _run(): await close_session() @@ -298,6 +363,67 @@ def test_segment_overflow_aborts(model_root): asyncio.run(_run()) +def test_short_segment_fails_closed(model_root): + """A segment that ends short must fail, not be accepted as complete. + + The file is preallocated to total_bytes, so the on-disk size still equals + total even with a zero-filled hole; completeness must be judged per-segment. + """ + payload = _safetensors_payload(4 * 1024 * 1024) # large enough to segment + + async def _run(): + await close_session() + runner, port = await _serve(_short_range_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/short.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/short.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED, queries.get_download(did).error + assert "incomplete" in (queries.get_download(did).error or "") + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + +def test_structural_validation_rejects_corrupt(model_root): + """A correctly sized but structurally invalid file fails closed (not retried). + + Regression for the dead structural gate: validation must key off the + destination extension, not the ``.part`` temp suffix. + """ + payload = _payload(300_000) # right size, but not a valid safetensors blob + + async def _run(): + await close_session() + runner, port = await _serve(_noranges_handler(payload)) + try: + url = f"http://127.0.0.1:{port}/model.safetensors" + did, final_path, temp = _insert("loras/corrupt.safetensors", url) + job = DownloadJob(JobSpec( + download_id=did, url=url, model_id="loras/corrupt.safetensors", + dest_path=final_path, temp_path=temp, + )) + status = await job.run() + assert status == DownloadStatus.FAILED, queries.get_download(did).error + assert not os.path.exists(final_path) + assert not os.path.exists(temp) + # Failed closed at first attempt, not re-queued as retryable. + assert queries.get_download(did).attempts == 0 + finally: + await runner.cleanup() + await close_session() + + asyncio.run(_run()) + + def test_rejects_oversized_known_download(model_root, monkeypatch): """A file whose advertised size exceeds the cap is rejected at probe.""" monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False) @@ -353,7 +479,7 @@ def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch): def test_manager_enqueue_to_completion(model_root): - payload = _payload(2 * 1024 * 1024) + payload = _safetensors_payload(2 * 1024 * 1024) async def _run(): await close_session()