Normalize malformed safetensors headers into StructuralError.

This commit is contained in:
Talmaj Marinc 2026-06-30 17:44:19 +02:00
parent 312b282ca8
commit 893ba2ad37

View File

@ -1,4 +1,4 @@
"""Cheap structural validation, no full read (PRD section 8.2).
"""Cheap structural validation, no full read.
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
the tensor table and the byte length of the data region. We assert
@ -53,6 +53,9 @@ def _validate_safetensors(path: str) -> None:
except (UnicodeDecodeError, json.JSONDecodeError) as e:
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
if not isinstance(header, dict):
raise StructuralError("safetensors header is not a JSON object")
data_len = 0
for name, entry in header.items():
if name == "__metadata__":
@ -62,7 +65,18 @@ def _validate_safetensors(path: str) -> None:
offsets = entry["data_offsets"]
if not (isinstance(offsets, list) and len(offsets) == 2):
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
data_len = max(data_len, int(offsets[1]))
begin, end = offsets
# bool is an int subclass; reject it explicitly to avoid True/False offsets.
if (
not isinstance(begin, int)
or not isinstance(end, int)
or isinstance(begin, bool)
or isinstance(end, bool)
or begin < 0
or end < begin
):
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
data_len = max(data_len, end)
expected = 8 + header_len + data_len
if file_size != expected: