fix: update seeder tests
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

Amp-Thread-ID: https://ampcode.com/threads/T-019c9122-ba77-768a-b827-1a4adea1c97e
Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
Luke Mino-Altherr 2026-02-24 11:34:02 -08:00
parent 6436190143
commit 1fa4c88907

View File

@ -1,7 +1,6 @@
"""Unit tests for the AssetSeeder background scanning class.""" """Unit tests for the AssetSeeder background scanning class."""
import threading import threading
import time
from unittest.mock import patch from unittest.mock import patch
import pytest import pytest
@ -43,17 +42,32 @@ class TestSeederStateTransitions:
def test_start_transitions_to_running( def test_start_transitions_to_running(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
started = fresh_seeder.start(roots=("models",)) barrier = threading.Event()
assert started is True reached = threading.Event()
status = fresh_seeder.get_status()
assert status.state in (State.RUNNING, State.IDLE) def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0)
return []
with patch(
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
):
started = fresh_seeder.start(roots=("models",))
assert started is True
assert reached.wait(timeout=2.0)
assert fresh_seeder.get_status().state == State.RUNNING
barrier.set()
def test_start_while_running_returns_false( def test_start_while_running_returns_false(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -61,7 +75,7 @@ class TestSeederStateTransitions:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
second_start = fresh_seeder.start(roots=("models",)) second_start = fresh_seeder.start(roots=("models",))
assert second_start is False assert second_start is False
@ -72,8 +86,10 @@ class TestSeederStateTransitions:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -81,7 +97,7 @@ class TestSeederStateTransitions:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
cancelled = fresh_seeder.cancel() cancelled = fresh_seeder.cancel()
assert cancelled is True assert cancelled is True
@ -140,24 +156,34 @@ class TestSeederProgress:
"""Test progress tracking.""" """Test progress tracking."""
def test_get_status_returns_progress_during_scan( def test_get_status_returns_progress_during_scan(
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder
): ):
progress_seen = []
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_build(*args, **kwargs):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return ["/path/file1.safetensors", "/path/file2.safetensors"] return ([], set(), 0)
with patch( paths = ["/path/file1.safetensors", "/path/file2.safetensors"]
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder.build_stub_specs", side_effect=slow_build),
patch("app.assets.seeder.insert_asset_specs", return_value=0),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
status = fresh_seeder.get_status() status = fresh_seeder.get_status()
assert status.state == State.RUNNING
assert status.progress is not None assert status.progress is not None
progress_seen.append(status.progress) assert status.progress.total == 2
barrier.set() barrier.set()
@ -187,10 +213,13 @@ class TestSeederCancellation:
): ):
insert_count = 0 insert_count = 0
barrier = threading.Event() barrier = threading.Event()
first_insert_done = threading.Event()
def slow_insert(specs, tags): def slow_insert(specs, tags):
nonlocal insert_count nonlocal insert_count
insert_count += 1 insert_count += 1
if insert_count == 1:
first_insert_done.set()
if insert_count >= 2: if insert_count >= 2:
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return len(specs) return len(specs)
@ -223,13 +252,13 @@ class TestSeederCancellation:
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)), patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.1) assert first_insert_done.wait(timeout=2.0)
fresh_seeder.cancel() fresh_seeder.cancel()
barrier.set() barrier.set()
fresh_seeder.wait(timeout=5.0) fresh_seeder.wait(timeout=5.0)
assert insert_count >= 1 assert 1 <= insert_count < 3 # 1500 paths / 500 batch = 3; cancel stopped early
class TestSeederErrorHandling: class TestSeederErrorHandling:
@ -338,8 +367,10 @@ class TestSeederThreadSafety:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -347,11 +378,11 @@ class TestSeederThreadSafety:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
assert reached.wait(timeout=2.0)
statuses = [] statuses = []
for _ in range(100): for _ in range(100):
statuses.append(fresh_seeder.get_status()) statuses.append(fresh_seeder.get_status())
time.sleep(0.001)
barrier.set() barrier.set()
@ -383,8 +414,10 @@ class TestSeederMarkMissing:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -392,7 +425,7 @@ class TestSeederMarkMissing:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
result = fresh_seeder.mark_missing_outside_prefixes() result = fresh_seeder.mark_missing_outside_prefixes()
assert result == 0 assert result == 0
@ -532,8 +565,10 @@ class TestSeederPauseResume:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -541,7 +576,7 @@ class TestSeederPauseResume:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
paused = fresh_seeder.pause() paused = fresh_seeder.pause()
assert paused is True assert paused is True
@ -557,8 +592,10 @@ class TestSeederPauseResume:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -566,7 +603,7 @@ class TestSeederPauseResume:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
fresh_seeder.pause() fresh_seeder.pause()
assert fresh_seeder.get_status().state == State.PAUSED assert fresh_seeder.get_status().state == State.PAUSED
@ -581,8 +618,10 @@ class TestSeederPauseResume:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -590,7 +629,7 @@ class TestSeederPauseResume:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
resumed = fresh_seeder.resume() resumed = fresh_seeder.resume()
assert resumed is False assert resumed is False
@ -612,10 +651,10 @@ class TestSeederPauseResume:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
reached_checkpoint.wait(timeout=1.0) assert reached_checkpoint.wait(timeout=2.0)
fresh_seeder.pause() fresh_seeder.pause()
time.sleep(0.05) assert fresh_seeder.get_status().state == State.PAUSED
cancelled = fresh_seeder.cancel() cancelled = fresh_seeder.cancel()
assert cancelled is True assert cancelled is True
@ -624,60 +663,6 @@ class TestSeederPauseResume:
fresh_seeder.wait(timeout=5.0) fresh_seeder.wait(timeout=5.0)
assert fresh_seeder.get_status().state == State.IDLE assert fresh_seeder.get_status().state == State.IDLE
def test_pause_blocks_scan_until_resume(self, fresh_seeder: AssetSeeder):
"""Verify scan blocks at checkpoint while paused."""
batch_count = 0
pause_detected = threading.Event()
resume_signal = threading.Event()
def counting_insert(specs, tags):
nonlocal batch_count
batch_count += 1
if batch_count == 1:
pause_detected.set()
resume_signal.wait(timeout=5.0)
return len(specs)
paths = [f"/path/file{i}.safetensors" for i in range(1000)]
specs = [
{
"abs_path": p,
"size_bytes": 100,
"mtime_ns": 0,
"info_name": f"file{i}",
"tags": [],
"fname": f"file{i}",
"metadata": None,
"hash": None,
"mime_type": None,
}
for i, p in enumerate(paths)
]
with (
patch("app.assets.seeder.dependencies_available", return_value=True),
patch("app.assets.seeder.sync_root_safely", return_value=set()),
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
patch("app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)),
patch("app.assets.seeder.insert_asset_specs", side_effect=counting_insert),
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
):
fresh_seeder.start(roots=("models",))
pause_detected.wait(timeout=2.0)
fresh_seeder.pause()
count_at_pause = batch_count
time.sleep(0.1)
assert batch_count == count_at_pause
fresh_seeder.resume()
resume_signal.set()
fresh_seeder.wait(timeout=5.0)
assert batch_count > count_at_pause
class TestSeederStopRestart: class TestSeederStopRestart:
"""Test stop and restart behavior.""" """Test stop and restart behavior."""
@ -685,8 +670,10 @@ class TestSeederStopRestart:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
def slow_collect(*args): def slow_collect(*args):
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -694,7 +681,7 @@ class TestSeederStopRestart:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
stopped = fresh_seeder.stop() stopped = fresh_seeder.stop()
assert stopped is True assert stopped is True
@ -706,12 +693,14 @@ class TestSeederStopRestart:
self, fresh_seeder: AssetSeeder, mock_dependencies self, fresh_seeder: AssetSeeder, mock_dependencies
): ):
barrier = threading.Event() barrier = threading.Event()
reached = threading.Event()
start_count = 0 start_count = 0
def slow_collect(*args): def slow_collect(*args):
nonlocal start_count nonlocal start_count
start_count += 1 start_count += 1
if start_count == 1: if start_count == 1:
reached.set()
barrier.wait(timeout=5.0) barrier.wait(timeout=5.0)
return [] return []
@ -719,7 +708,7 @@ class TestSeederStopRestart:
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
): ):
fresh_seeder.start(roots=("models",)) fresh_seeder.start(roots=("models",))
time.sleep(0.05) assert reached.wait(timeout=2.0)
barrier.set() barrier.set()
restarted = fresh_seeder.restart() restarted = fresh_seeder.restart()