mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 13:19:23 +08:00
Improve _finalize checks for downloads.
This commit is contained in:
parent
58392bf7a6
commit
4ccaaa6f37
@ -422,24 +422,54 @@ class DownloadJob:
|
||||
self._set_status(DownloadStatus.VERIFYING)
|
||||
|
||||
total = self.state.total_bytes
|
||||
actual_size = os.path.getsize(self.spec.temp_path)
|
||||
if total is not None and actual_size != total:
|
||||
segmented = len(self.state.segments) > 1
|
||||
if segmented:
|
||||
# The .part was preallocated to total_bytes, so its on-disk size is
|
||||
# not evidence of completeness: a segment that ends short (truncated
|
||||
# 206 / server closes mid-range) leaves a zero-filled hole while the
|
||||
# file size still equals total. Verify each segment wrote its full
|
||||
# planned range, and trust the byte counter (== sum of segments)
|
||||
# rather than os.path.getsize for the total check.
|
||||
for seg in self.state.segments:
|
||||
if seg.bytes_done != seg.length:
|
||||
raise FatalError(
|
||||
f"segment {seg.idx} incomplete: wrote {seg.bytes_done} "
|
||||
f"of {seg.length} bytes"
|
||||
)
|
||||
observed = self.state.bytes_done
|
||||
else:
|
||||
# Single-stream writes a contiguous prefix, so the on-disk size is
|
||||
# an independent witness of how much actually landed.
|
||||
observed = os.path.getsize(self.spec.temp_path)
|
||||
if total is not None and observed != total:
|
||||
raise FatalError(
|
||||
f"size mismatch: wrote {actual_size} of {total} bytes"
|
||||
f"size mismatch: wrote {observed} of {total} bytes"
|
||||
)
|
||||
|
||||
# Structural gate (cheap, no full read) then optional sha256 (full read).
|
||||
await asyncio.to_thread(structural.validate, self.spec.temp_path)
|
||||
if self.spec.expected_sha256:
|
||||
# Both failures are non-retryable (a truncated/corrupt or mismatched file
|
||||
# will not heal on retry), so surface them as FatalError rather than
|
||||
# letting the plain Exceptions fall through to the retryable handler.
|
||||
# ``temp_path`` carries the ``.part`` suffix; pass ``dest_path`` so the
|
||||
# structural check detects the real file format instead of skipping it.
|
||||
try:
|
||||
await asyncio.to_thread(
|
||||
checksum.verify_sha256, self.spec.temp_path, self.spec.expected_sha256
|
||||
structural.validate, self.spec.temp_path, self.spec.dest_path
|
||||
)
|
||||
if self.spec.expected_sha256:
|
||||
await asyncio.to_thread(
|
||||
checksum.verify_sha256,
|
||||
self.spec.temp_path,
|
||||
self.spec.expected_sha256,
|
||||
)
|
||||
except (structural.StructuralError, checksum.ChecksumError) as e:
|
||||
raise FatalError(str(e)) from e
|
||||
|
||||
os.makedirs(os.path.dirname(self.spec.dest_path), exist_ok=True)
|
||||
os.replace(self.spec.temp_path, self.spec.dest_path)
|
||||
logging.info(
|
||||
"[model_downloader] completed %s (%d bytes)",
|
||||
self.spec.model_id, actual_size,
|
||||
self.spec.model_id, observed,
|
||||
)
|
||||
# Catalog into the assets system (blake3 dedup identity). Best-effort.
|
||||
await dedup.register_completed(self.spec.dest_path)
|
||||
|
||||
@ -12,6 +12,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
from typing import Optional
|
||||
|
||||
_SAFETENSORS_EXTS = (".safetensors", ".sft")
|
||||
# A sane upper bound so a corrupt header length can't make us read gigabytes.
|
||||
@ -22,9 +23,15 @@ class StructuralError(Exception):
|
||||
"""The file failed its structural integrity check."""
|
||||
|
||||
|
||||
def validate(path: str) -> None:
|
||||
"""Validate the file at ``path``. Raises :class:`StructuralError` on failure."""
|
||||
lower = path.lower()
|
||||
def validate(path: str, name_hint: Optional[str] = None) -> None:
|
||||
"""Validate the file at ``path``. Raises :class:`StructuralError` on failure.
|
||||
|
||||
The file format is detected from ``name_hint`` when provided, otherwise from
|
||||
``path``. Callers that download into a temp file with an opaque suffix (e.g.
|
||||
``*.comfy-download.part``) must pass the final destination name as
|
||||
``name_hint`` so the format check is not silently skipped.
|
||||
"""
|
||||
lower = (name_hint or path).lower()
|
||||
if lower.endswith(_SAFETENSORS_EXTS):
|
||||
_validate_safetensors(path)
|
||||
# No structural check for other formats; the size + (optional) checksum
|
||||
|
||||
@ -8,7 +8,9 @@ so no pytest-asyncio plugin is required.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import os
|
||||
import struct
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
@ -28,6 +30,34 @@ def _payload(n: int) -> bytes:
|
||||
return bytes((i * 37 + 11) % 256 for i in range(n))
|
||||
|
||||
|
||||
def _safetensors_payload(total: int) -> bytes:
|
||||
"""A structurally valid ``.safetensors`` blob of exactly ``total`` bytes.
|
||||
|
||||
Success-path tests download to ``.safetensors`` destinations, which the
|
||||
engine now structurally validates before the atomic rename, so their
|
||||
payloads must parse as real safetensors (header length + JSON header +
|
||||
data region whose size matches the declared ``data_offsets``).
|
||||
"""
|
||||
def _header(data_len: int) -> bytes:
|
||||
return json.dumps(
|
||||
{"w": {"dtype": "U8", "shape": [data_len], "data_offsets": [0, data_len]}}
|
||||
).encode("utf-8")
|
||||
|
||||
# The header's byte length depends on the digit count of ``data_len``, so
|
||||
# iterate until ``total == 8 + len(header) + data_len`` is self-consistent.
|
||||
data_len = total - 8 - len(_header(total))
|
||||
for _ in range(8):
|
||||
header = _header(data_len)
|
||||
new_data_len = total - 8 - len(header)
|
||||
if new_data_len == data_len:
|
||||
break
|
||||
data_len = new_data_len
|
||||
assert data_len >= 0, "total too small for a safetensors payload"
|
||||
header = _header(data_len)
|
||||
body = bytes((i * 37 + 11) % 256 for i in range(data_len))
|
||||
return struct.pack("<Q", len(header)) + header + body
|
||||
|
||||
|
||||
def _range_handler(payload: bytes):
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
@ -104,6 +134,41 @@ def _overflow_range_handler(payload: bytes, extra: int = 256 * 1024):
|
||||
return handler
|
||||
|
||||
|
||||
def _short_range_handler(payload: bytes, drop: int = 64 * 1024):
|
||||
"""A 206 server that returns fewer bytes than requested for later segments.
|
||||
|
||||
Simulates a server cleanly closing a range connection early. The response
|
||||
is internally consistent (Content-Length matches the short body), so the
|
||||
client sees no error and the segment just ends short, leaving a zero-filled
|
||||
hole in the preallocated file.
|
||||
"""
|
||||
|
||||
async def handler(request: web.Request) -> web.Response:
|
||||
rng = request.headers.get("Range")
|
||||
if rng:
|
||||
spec = rng.split("=", 1)[1]
|
||||
s, _, e = spec.partition("-")
|
||||
start = int(s)
|
||||
end = int(e) if e else len(payload) - 1
|
||||
chunk = payload[start : end + 1]
|
||||
if start > 0 and len(chunk) > drop:
|
||||
chunk = chunk[:-drop] # truncate a non-first segment
|
||||
return web.Response(
|
||||
status=206,
|
||||
body=chunk,
|
||||
headers={
|
||||
"Content-Range": f"bytes {start}-{end}/{len(payload)}",
|
||||
"Accept-Ranges": "bytes",
|
||||
"ETag": PAYLOAD_ETAG,
|
||||
},
|
||||
)
|
||||
return web.Response(
|
||||
status=200, body=payload, headers={"Accept-Ranges": "bytes", "ETag": PAYLOAD_ETAG}
|
||||
)
|
||||
|
||||
return handler
|
||||
|
||||
|
||||
def _unbounded_handler(total: int, chunk: int = 16384):
|
||||
"""A 200 stream with no Content-Length / Accept-Ranges (unknown length)."""
|
||||
|
||||
@ -151,7 +216,7 @@ def _insert(model_id: str, url: str, status: str = DownloadStatus.QUEUED) -> tup
|
||||
|
||||
|
||||
def test_single_stream_download(model_root):
|
||||
payload = _payload(300_000)
|
||||
payload = _safetensors_payload(300_000)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
@ -178,7 +243,7 @@ def test_single_stream_download(model_root):
|
||||
|
||||
|
||||
def test_segmented_download(model_root):
|
||||
payload = _payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
|
||||
payload = _safetensors_payload(4 * 1024 * 1024) # 4 MiB -> multiple segments
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
@ -206,7 +271,7 @@ def test_segmented_download(model_root):
|
||||
|
||||
|
||||
def test_resume_from_partial(model_root):
|
||||
payload = _payload(512 * 1024) # < 1 MiB -> single segment, but ranges work
|
||||
payload = _safetensors_payload(512 * 1024) # < 1 MiB -> single segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
@ -298,6 +363,67 @@ def test_segment_overflow_aborts(model_root):
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_short_segment_fails_closed(model_root):
|
||||
"""A segment that ends short must fail, not be accepted as complete.
|
||||
|
||||
The file is preallocated to total_bytes, so the on-disk size still equals
|
||||
total even with a zero-filled hole; completeness must be judged per-segment.
|
||||
"""
|
||||
payload = _safetensors_payload(4 * 1024 * 1024) # large enough to segment
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_short_range_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/short.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/short.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED, queries.get_download(did).error
|
||||
assert "incomplete" in (queries.get_download(did).error or "")
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_structural_validation_rejects_corrupt(model_root):
|
||||
"""A correctly sized but structurally invalid file fails closed (not retried).
|
||||
|
||||
Regression for the dead structural gate: validation must key off the
|
||||
destination extension, not the ``.part`` temp suffix.
|
||||
"""
|
||||
payload = _payload(300_000) # right size, but not a valid safetensors blob
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
runner, port = await _serve(_noranges_handler(payload))
|
||||
try:
|
||||
url = f"http://127.0.0.1:{port}/model.safetensors"
|
||||
did, final_path, temp = _insert("loras/corrupt.safetensors", url)
|
||||
job = DownloadJob(JobSpec(
|
||||
download_id=did, url=url, model_id="loras/corrupt.safetensors",
|
||||
dest_path=final_path, temp_path=temp,
|
||||
))
|
||||
status = await job.run()
|
||||
assert status == DownloadStatus.FAILED, queries.get_download(did).error
|
||||
assert not os.path.exists(final_path)
|
||||
assert not os.path.exists(temp)
|
||||
# Failed closed at first attempt, not re-queued as retryable.
|
||||
assert queries.get_download(did).attempts == 0
|
||||
finally:
|
||||
await runner.cleanup()
|
||||
await close_session()
|
||||
|
||||
asyncio.run(_run())
|
||||
|
||||
|
||||
def test_rejects_oversized_known_download(model_root, monkeypatch):
|
||||
"""A file whose advertised size exceeds the cap is rejected at probe."""
|
||||
monkeypatch.setattr(args, "download_max_bytes", 100_000, raising=False)
|
||||
@ -353,7 +479,7 @@ def test_unknown_length_capped_by_max_bytes(model_root, monkeypatch):
|
||||
|
||||
|
||||
def test_manager_enqueue_to_completion(model_root):
|
||||
payload = _payload(2 * 1024 * 1024)
|
||||
payload = _safetensors_payload(2 * 1024 * 1024)
|
||||
|
||||
async def _run():
|
||||
await close_session()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user