diff --git a/app/model_downloader/engine/writer.py b/app/model_downloader/engine/writer.py index 429e4e197..a7c10dc30 100644 --- a/app/model_downloader/engine/writer.py +++ b/app/model_downloader/engine/writer.py @@ -8,18 +8,25 @@ A single file descriptor is opened for the whole download. Segments write to their own offsets with ``os.pwrite`` — which is offset-addressed and atomic per call, so concurrent segment writers need no extra locking. Per-chunk fsync is avoided; we fsync once at completion. + +``os.pwrite`` is unavailable on Windows, so there we fall back to +``os.lseek`` + ``os.write`` guarded by a per-writer lock (the seek/write pair +is not atomic, so concurrent segment writers must be serialized). """ from __future__ import annotations import asyncio import os +import threading from concurrent.futures import ThreadPoolExecutor from typing import Optional # One shared, bounded pool for all download disk I/O. _EXECUTOR = ThreadPoolExecutor(max_workers=8, thread_name_prefix="dl-writer") +_HAS_PWRITE = hasattr(os, "pwrite") + class FileWriter: """Owns the ``.part`` file descriptor for one download.""" @@ -27,6 +34,8 @@ class FileWriter: def __init__(self, path: str) -> None: self.path = path self._fd: Optional[int] = None + # Serializes lseek+write on platforms without os.pwrite (Windows). + self._seek_lock = threading.Lock() def _open(self) -> None: os.makedirs(os.path.dirname(self.path), exist_ok=True) @@ -52,18 +61,27 @@ class FileWriter: ) def _pwrite_all(self, data: bytes, offset: int) -> None: - """``os.pwrite`` may write fewer bytes than requested (signal + """A positioned write may write fewer bytes than requested (signal interruption, near-ENOSPC); loop until every byte lands so we never - leave a gap while the caller advances by the full chunk length.""" + leave a gap while the caller advances by the full chunk length. + + Uses ``os.pwrite`` where available (offset-addressed, atomic per call). + On Windows it falls back to ``os.lseek`` + ``os.write`` under a lock, + since that pair is not atomic across concurrent segment writers.""" assert self._fd is not None, "writer not opened" view = memoryview(data) written = 0 total = len(view) while written < total: - n = os.pwrite(self._fd, view[written:], offset + written) + if _HAS_PWRITE: + n = os.pwrite(self._fd, view[written:], offset + written) + else: + with self._seek_lock: + os.lseek(self._fd, offset + written, os.SEEK_SET) + n = os.write(self._fd, view[written:]) if n == 0: raise OSError( - f"os.pwrite wrote 0 bytes at offset {offset + written} " + f"positioned write wrote 0 bytes at offset {offset + written} " f"({written}/{total} bytes written)" ) written += n