mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 05:31:03 +08:00
72 lines
2.2 KiB
Python
72 lines
2.2 KiB
Python
"""Unit tests for the segment planner and structural safetensors validation."""
|
|
|
|
from __future__ import annotations
|
|
|
|
import json
|
|
import struct
|
|
|
|
import pytest
|
|
|
|
from app.model_downloader.engine.planner import (
|
|
effective_segment_count,
|
|
plan_segments,
|
|
)
|
|
from app.model_downloader.verify import structural
|
|
|
|
|
|
# ----- planner -----
|
|
|
|
|
|
def test_plan_segments_covers_full_range_contiguously():
    """Four segments planned over 1000 bytes must tile bytes 0..999 exactly."""
    size = 1000
    segments = plan_segments(size, 4)

    assert len(segments) == 4
    first, last = segments[0], segments[-1]
    assert first.start == 0
    assert last.end == size - 1

    # Adjacent segments must abut: each one starts right after the previous
    # one ends, which rules out both gaps and overlaps.
    for previous, current in zip(segments[:-1], segments[1:]):
        assert current.start == previous.end + 1

    covered = sum(segment.length for segment in segments)
    assert covered == size
|
|
|
|
|
|
def test_effective_segment_count_falls_back_to_single():
    """Check the segment count chosen for each fallback and the normal case."""
    scenarios = [
        # (positional arguments, expected segment count)
        ((10_000_000, False, 8), 1),  # No range support -> single
        ((None, True, 8), 1),  # Unknown size -> single
        ((1024, True, 8), 1),  # Tiny file -> fewer segments than configured
        ((1_000_000_000, True, 8), 8),  # Large file with range support -> configured count
    ]
    for arguments, expected in scenarios:
        assert effective_segment_count(*arguments) == expected
|
|
|
|
|
|
# ----- structural -----
|
|
|
|
|
|
def _make_safetensors(tensor_data_len: int, *, corrupt_size: bool = False) -> bytes:
    """Build the bytes of a minimal single-tensor safetensors file.

    Layout: an 8-byte little-endian unsigned header length, the UTF-8 JSON
    header, then ``tensor_data_len`` zero bytes of tensor data.

    Args:
        tensor_data_len: Size of the tensor payload in bytes. Because the
            tensor is declared as ``U8``, this is also its element count.
        corrupt_size: When true, drop the final payload byte so the file is
            one byte shorter than the header's ``data_offsets`` declares.

    Returns:
        The serialized file contents.

    Raises:
        ValueError: If ``corrupt_size`` is requested for an empty payload,
            which has no byte to remove (the result would silently be an
            uncorrupted file).
    """
    if corrupt_size and tensor_data_len < 1:
        raise ValueError("cannot truncate an empty tensor payload")

    # U8 is one byte per element, so shape [n] occupies exactly the n bytes
    # claimed by data_offsets. A multi-byte dtype such as F32 would make the
    # header self-contradictory for validators that cross-check the two.
    header = {
        "t": {
            "dtype": "U8",
            "shape": [tensor_data_len],
            "data_offsets": [0, tensor_data_len],
        }
    }
    header_bytes = json.dumps(header).encode("utf-8")
    body = b"\x00" * tensor_data_len
    if corrupt_size:
        body = body[:-1]  # truncate one byte
    return struct.pack("<Q", len(header_bytes)) + header_bytes + body
|
|
|
|
|
|
def test_structural_valid_safetensors(tmp_path):
    """Validation accepts a ``.safetensors`` file with an intact 256-byte body."""
    payload = _make_safetensors(256)
    target = tmp_path / "ok.safetensors"
    target.write_bytes(payload)
    # Passing means this call returns without raising.
    structural.validate(str(target))
|
|
|
|
|
|
def test_structural_detects_truncation(tmp_path):
    """Validation raises ``StructuralError`` when the body is one byte short."""
    truncated = _make_safetensors(256, corrupt_size=True)
    target = tmp_path / "bad.safetensors"
    target.write_bytes(truncated)
    with pytest.raises(structural.StructuralError):
        structural.validate(str(target))
|
|
|
|
|
|
def test_structural_skips_unknown_extension(tmp_path):
    """A ``.bin`` file with arbitrary content is expected to validate cleanly."""
    target = tmp_path / "weights.bin"
    target.write_bytes(b"anything")
    # No structural check applies here, so the call must simply not raise.
    structural.validate(str(target))
|