Consolidate hash functions into single implementation

Extract file open/seek/restore logic into _open_for_hashing context
manager and use a single hash loop in compute_blake3_hash for both
file paths and file objects.

Amp-Thread-ID: https://ampcode.com/threads/T-019ccb05-0db1-7206-8bd9-1c2efb898fef
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-03-07 17:22:59 -08:00
parent 42edf71854
commit 7f00f48c96

View File

@ -1,7 +1,8 @@
import io import io
import os import os
from contextlib import contextmanager
from dataclasses import dataclass from dataclasses import dataclass
from typing import IO, Any, Callable from typing import IO, Any, Callable, Iterator
from blake3 import blake3 from blake3 import blake3
@ -20,6 +21,29 @@ class HashCheckpoint:
file_size: int = 0 file_size: int = 0
@contextmanager
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
"""Yield (file_object, is_path) with appropriate setup/teardown."""
if hasattr(fp, "read"):
seekable = getattr(fp, "seekable", lambda: False)()
orig_pos = None
if seekable:
try:
orig_pos = fp.tell()
if orig_pos != 0:
fp.seek(0)
except io.UnsupportedOperation:
orig_pos = None
try:
yield fp, False
finally:
if orig_pos is not None:
fp.seek(orig_pos)
else:
with open(os.fspath(fp), "rb") as f:
yield f, True
def compute_blake3_hash( def compute_blake3_hash(
fp: str | IO[bytes], fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK, chunk_size: int = DEFAULT_CHUNK,
@ -42,12 +66,11 @@ def compute_blake3_hash(
(None, checkpoint) on interruption (file paths only), or (None, checkpoint) on interruption (file paths only), or
(None, None) on interruption of a file object (None, None) on interruption of a file object
""" """
if hasattr(fp, "read"): if chunk_size <= 0:
digest = _hash_file_obj(fp, chunk_size, interrupt_check) chunk_size = DEFAULT_CHUNK
return digest, None
with open(os.fspath(fp), "rb") as f: with _open_for_hashing(fp) as (f, is_path):
if checkpoint is not None: if checkpoint is not None and is_path:
f.seek(checkpoint.bytes_processed) f.seek(checkpoint.bytes_processed)
h = checkpoint.hasher h = checkpoint.hasher
bytes_processed = checkpoint.bytes_processed bytes_processed = checkpoint.bytes_processed
@ -55,15 +78,14 @@ def compute_blake3_hash(
h = blake3() h = blake3()
bytes_processed = 0 bytes_processed = 0
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
while True: while True:
if interrupt_check is not None and interrupt_check(): if interrupt_check is not None and interrupt_check():
return None, HashCheckpoint( if is_path:
bytes_processed=bytes_processed, return None, HashCheckpoint(
hasher=h, bytes_processed=bytes_processed,
) hasher=h,
)
return None, None
chunk = f.read(chunk_size) chunk = f.read(chunk_size)
if not chunk: if not chunk:
break break
@ -71,38 +93,3 @@ def compute_blake3_hash(
bytes_processed += len(chunk) bytes_processed += len(chunk)
return h.hexdigest(), None return h.hexdigest(), None
def _hash_file_obj(
    file_obj: IO,
    chunk_size: int = DEFAULT_CHUNK,
    interrupt_check: InterruptCheck | None = None,
) -> str | None:
    """Compute the BLAKE3 hex digest of an already-open binary stream.

    Seekable streams are rewound to the start before reading and the
    caller's original position is restored afterwards.  Returns the hex
    digest, or ``None`` when ``interrupt_check`` reports an interruption
    before a chunk is read.
    """
    if chunk_size <= 0:
        chunk_size = DEFAULT_CHUNK

    # Capture the caller's cursor position so it can be restored.
    restore_to = None
    can_seek = getattr(file_obj, "seekable", lambda: False)()
    if can_seek:
        try:
            restore_to = file_obj.tell()
            if restore_to:
                file_obj.seek(0)
        except io.UnsupportedOperation:
            # seekable() was optimistic; fall back to streaming in place.
            can_seek = False
            restore_to = None

    try:
        hasher = blake3()
        while True:
            # Check for interruption before each read so a cancelled hash
            # bails out without consuming more of the stream.
            if interrupt_check is not None and interrupt_check():
                return None
            block = file_obj.read(chunk_size)
            if not block:
                return hasher.hexdigest()
            hasher.update(block)
    finally:
        if can_seek and restore_to is not None:
            file_obj.seek(restore_to)