mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-03 21:20:49 +08:00
Normalize malformed safetensors headers into StructuralError.
This commit is contained in:
parent
312b282ca8
commit
893ba2ad37
@ -1,4 +1,4 @@
|
|||||||
"""Cheap structural validation, no full read (PRD section 8.2).
|
"""Cheap structural validation, no full read.
|
||||||
|
|
||||||
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
|
For ``.safetensors``/``.sft`` we parse the header (first few KB): it carries
|
||||||
the tensor table and the byte length of the data region. We assert
|
the tensor table and the byte length of the data region. We assert
|
||||||
@ -53,6 +53,9 @@ def _validate_safetensors(path: str) -> None:
|
|||||||
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
except (UnicodeDecodeError, json.JSONDecodeError) as e:
|
||||||
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
|
raise StructuralError(f"safetensors header is not valid JSON: {e}") from e
|
||||||
|
|
||||||
|
if not isinstance(header, dict):
|
||||||
|
raise StructuralError("safetensors header is not a JSON object")
|
||||||
|
|
||||||
data_len = 0
|
data_len = 0
|
||||||
for name, entry in header.items():
|
for name, entry in header.items():
|
||||||
if name == "__metadata__":
|
if name == "__metadata__":
|
||||||
@ -62,7 +65,18 @@ def _validate_safetensors(path: str) -> None:
|
|||||||
offsets = entry["data_offsets"]
|
offsets = entry["data_offsets"]
|
||||||
if not (isinstance(offsets, list) and len(offsets) == 2):
|
if not (isinstance(offsets, list) and len(offsets) == 2):
|
||||||
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||||
data_len = max(data_len, int(offsets[1]))
|
begin, end = offsets
|
||||||
|
# bool is an int subclass; reject it explicitly to avoid True/False offsets.
|
||||||
|
if (
|
||||||
|
not isinstance(begin, int)
|
||||||
|
or not isinstance(end, int)
|
||||||
|
or isinstance(begin, bool)
|
||||||
|
or isinstance(end, bool)
|
||||||
|
or begin < 0
|
||||||
|
or end < begin
|
||||||
|
):
|
||||||
|
raise StructuralError(f"tensor {name!r} has malformed data_offsets")
|
||||||
|
data_len = max(data_len, end)
|
||||||
|
|
||||||
expected = 8 + header_len + data_len
|
expected = 8 + header_len + data_len
|
||||||
if file_size != expected:
|
if file_size != expected:
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user