ComfyUI/tests-unit/seeder_test/test_seeder.py
Luke Mino-Altherr f9d85fa176 Fix race in enqueue_enrich drain: make pending-to-start handoff atomic
Change _lock from Lock to RLock and move the start_enrich call inside the
lock-held block so that enqueue_enrich cannot interleave between clearing
_pending_enrich and starting the enrichment scan. This prevents a concurrent
enqueue_enrich from stealing the IDLE slot and causing the drained payload
to be silently dropped.

Add tests covering:
- pending enrich runs after scan completes
- enqueue during drain does not lose work
- concurrent enqueue during drain is queued for the next cycle

Amp-Thread-ID: https://ampcode.com/threads/T-019cfe02-5710-7506-ae80-34bf16c0171a
Co-authored-by: Amp <amp@ampcode.com>
2026-03-17 19:08:44 -07:00

1084 lines
39 KiB
Python

"""Unit tests for the _AssetSeeder background scanning class."""
import threading
import time
from unittest.mock import patch
import pytest
from app.assets.database.queries.asset_reference import UnenrichedReferenceRow
from app.assets.seeder import _AssetSeeder, Progress, ScanInProgressError, ScanPhase, State
@pytest.fixture
def fresh_seeder():
    """Yield a brand-new _AssetSeeder and shut it down after the test."""
    instance = _AssetSeeder()
    yield instance
    # Teardown: shutdown is invoked with a one-second timeout.
    instance.shutdown(timeout=1.0)
@pytest.fixture
def mock_dependencies():
    """Stub every external collaborator of the seeder for the test's duration.

    Each name below is patched on ``app.assets.seeder`` with a mock that
    returns the paired value.
    """
    stubbed_returns = {
        "dependencies_available": True,
        "sync_root_safely": set(),
        "collect_paths_for_roots": [],
        "build_asset_specs": ([], set(), 0),
        "insert_asset_specs": 0,
        "get_unenriched_assets_for_roots": [],
        "enrich_assets_batch": (0, 0),
    }
    patchers = [
        patch(f"app.assets.seeder.{name}", return_value=value)
        for name, value in stubbed_returns.items()
    ]
    active = []
    try:
        for patcher in patchers:
            patcher.start()
            active.append(patcher)
        yield
    finally:
        # Undo in reverse order, mirroring how nested ``with`` blocks unwind.
        for patcher in reversed(active):
            patcher.stop()
class TestSeederStateTransitions:
    """Test state machine transitions."""

    @staticmethod
    def _blocking_collect():
        """Build a ``collect_paths_for_roots`` stand-in that parks its caller.

        Returns:
            ``(reached, release, side_effect)``. ``reached`` is set as soon as
            ``side_effect`` is entered; ``side_effect`` then blocks until
            ``release`` is set (or 5 seconds pass) and returns an empty list.
        """
        reached = threading.Event()
        release = threading.Event()

        def side_effect(*args):
            reached.set()
            release.wait(timeout=5.0)
            return []

        return reached, release, side_effect

    def test_initial_state_is_idle(self, fresh_seeder: _AssetSeeder):
        """A seeder that was never started reports IDLE."""
        assert fresh_seeder.get_status().state == State.IDLE

    def test_start_transitions_to_running(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """start() returns True and the status is RUNNING while the scan is parked."""
        reached, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                started = fresh_seeder.start(roots=("models",))
                assert started is True
                assert reached.wait(timeout=2.0)
                assert fresh_seeder.get_status().state == State.RUNNING
            finally:
                # Always release the scan thread, even if an assertion failed.
                release.set()

    def test_start_while_running_returns_false(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """A second start() issued while a scan is in flight returns False."""
        reached, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                second_start = fresh_seeder.start(roots=("models",))
                assert second_start is False
            finally:
                release.set()

    def test_cancel_transitions_to_cancelling(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """cancel() on a running scan returns True and the status becomes CANCELLING."""
        reached, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                cancelled = fresh_seeder.cancel()
                assert cancelled is True
                assert fresh_seeder.get_status().state == State.CANCELLING
            finally:
                release.set()

    def test_cancel_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
        """cancel() with no scan in flight returns False."""
        cancelled = fresh_seeder.cancel()
        assert cancelled is False

    def test_state_returns_to_idle_after_completion(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """Once wait() reports completion the status is IDLE again."""
        fresh_seeder.start(roots=("models",))
        completed = fresh_seeder.wait(timeout=5.0)
        assert completed is True
        assert fresh_seeder.get_status().state == State.IDLE
class TestSeederWait:
    """Test wait() behavior."""

    def test_wait_blocks_until_complete(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """wait() returns True and the seeder is IDLE afterwards."""
        fresh_seeder.start(roots=("models",))
        completed = fresh_seeder.wait(timeout=5.0)
        assert completed is True
        assert fresh_seeder.get_status().state == State.IDLE

    def test_wait_returns_false_on_timeout(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """wait() returns False when the scan outlasts the given timeout."""
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=10.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                completed = fresh_seeder.wait(timeout=0.1)
                assert completed is False
            finally:
                # Release the parked scan thread even when the assertion fails;
                # otherwise it would sit in the 10-second wait.
                barrier.set()

    def test_wait_when_idle_returns_true(self, fresh_seeder: _AssetSeeder):
        """wait() on a seeder that was never started returns True."""
        completed = fresh_seeder.wait(timeout=1.0)
        assert completed is True
class TestSeederProgress:
    """Test progress tracking."""

    def test_get_status_returns_progress_during_scan(
        self, fresh_seeder: _AssetSeeder
    ):
        """While the scan is parked in build_asset_specs, progress.total is 2."""
        barrier = threading.Event()
        reached = threading.Event()

        def slow_build(*args, **kwargs):
            reached.set()
            barrier.wait(timeout=5.0)
            return ([], set(), 0)

        paths = ["/path/file1.safetensors", "/path/file2.safetensors"]
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
            patch("app.assets.seeder.build_asset_specs", side_effect=slow_build),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                status = fresh_seeder.get_status()
                assert status.state == State.RUNNING
                assert status.progress is not None
                assert status.progress.total == 2
            finally:
                # Unblock the scan thread regardless of the assertion outcome.
                barrier.set()

    def test_progress_callback_is_invoked(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """A progress_callback passed to start() receives at least one update."""
        progress_updates: list[Progress] = []

        def callback(p: Progress):
            progress_updates.append(p)

        with patch(
            "app.assets.seeder.collect_paths_for_roots",
            return_value=[f"/path/file{i}.safetensors" for i in range(10)],
        ):
            fresh_seeder.start(roots=("models",), progress_callback=callback)
            fresh_seeder.wait(timeout=5.0)
        assert len(progress_updates) > 0
class TestSeederCancellation:
    """Test cancellation behavior."""

    def test_scan_commits_partial_progress_on_cancellation(
        self, fresh_seeder: _AssetSeeder
    ):
        """Cancelling after the first insert stops the scan before all batches run."""
        insert_count = 0
        barrier = threading.Event()
        first_insert_done = threading.Event()

        def slow_insert(specs, tags):
            nonlocal insert_count
            insert_count += 1
            if insert_count == 1:
                first_insert_done.set()
            if insert_count >= 2:
                # Later inserts park here until the test releases them.
                barrier.wait(timeout=5.0)
            return len(specs)

        paths = [f"/path/file{i}.safetensors" for i in range(1500)]
        specs = [
            {
                "abs_path": p,
                "size_bytes": 100,
                "mtime_ns": 0,
                "info_name": f"file{i}",
                "tags": [],
                "fname": f"file{i}",
                "metadata": None,
                "hash": None,
                "mime_type": None,
            }
            for i, p in enumerate(paths)
        ]
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
            patch(
                "app.assets.seeder.build_asset_specs", return_value=(specs, set(), 0)
            ),
            patch("app.assets.seeder.insert_asset_specs", side_effect=slow_insert),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert first_insert_done.wait(timeout=2.0)
                fresh_seeder.cancel()
            finally:
                # Never leave the scan thread parked if the handshake failed.
                barrier.set()
            fresh_seeder.wait(timeout=5.0)
        assert 1 <= insert_count < 3  # 1500 paths / 500 batch = 3; cancel stopped early
class TestSeederErrorHandling:
    """Test error handling behavior."""

    def test_database_errors_captured_in_status(self, fresh_seeder: _AssetSeeder):
        """An exception raised by insert_asset_specs shows up in status.errors."""
        only_path = "/path/file.safetensors"
        only_spec = {
            "abs_path": "/path/file.safetensors",
            "size_bytes": 100,
            "mtime_ns": 0,
            "info_name": "file",
            "tags": [],
            "fname": "file",
            "metadata": None,
            "hash": None,
            "mime_type": None,
        }
        db_failure = Exception("DB connection failed")
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch(
                "app.assets.seeder.collect_paths_for_roots",
                return_value=[only_path],
            ),
            patch(
                "app.assets.seeder.build_asset_specs",
                return_value=([only_spec], set(), 0),
            ),
            patch("app.assets.seeder.insert_asset_specs", side_effect=db_failure),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)
            outcome = fresh_seeder.get_status()
            assert len(outcome.errors) > 0
            assert "DB connection failed" in outcome.errors[0]

    def test_dependencies_unavailable_captured_in_errors(
        self, fresh_seeder: _AssetSeeder
    ):
        """With dependencies_available() False, the first error mentions dependencies."""
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)
            outcome = fresh_seeder.get_status()
            assert len(outcome.errors) > 0
            first_error = outcome.errors[0]
            assert "dependencies" in first_error.lower()

    def test_thread_crash_resets_state_to_idle(self, fresh_seeder: _AssetSeeder):
        """A RuntimeError from sync_root_safely leaves the seeder IDLE with an error."""
        crash = RuntimeError("Unexpected crash")
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", side_effect=crash),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)
            outcome = fresh_seeder.get_status()
            assert outcome.state == State.IDLE
            assert len(outcome.errors) > 0
class TestSeederThreadSafety:
    """Test thread safety of concurrent operations."""

    def test_concurrent_start_calls_spawn_only_one_thread(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """Of ten racing start() calls, exactly one returns True."""
        barrier = threading.Event()

        def slow_collect(*args):
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            results = []

            def try_start():
                results.append(fresh_seeder.start(roots=("models",)))

            threads = [threading.Thread(target=try_start) for _ in range(10)]
            try:
                for t in threads:
                    t.start()
                for t in threads:
                    t.join()
            finally:
                # Release the winning scan thread whatever happened above.
                barrier.set()
        assert sum(results) == 1

    def test_get_status_safe_during_scan(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """100 get_status() calls during a scan each report RUNNING, IDLE or CANCELLING."""
        barrier = threading.Event()
        reached = threading.Event()

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            statuses = []
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                for _ in range(100):
                    statuses.append(fresh_seeder.get_status())
            finally:
                barrier.set()
        assert all(
            s.state in (State.RUNNING, State.IDLE, State.CANCELLING)
            for s in statuses
        )
class TestSeederMarkMissing:
    """Test mark_missing_outside_prefixes behavior."""

    def test_mark_missing_when_idle(self, fresh_seeder: _AssetSeeder):
        """When idle, the known prefixes are forwarded and the helper's count returned."""
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch(
                "app.assets.seeder.get_all_known_prefixes",
                return_value=["/models", "/input", "/output"],
            ),
            patch(
                "app.assets.seeder.mark_missing_outside_prefixes_safely", return_value=5
            ) as mock_mark,
        ):
            result = fresh_seeder.mark_missing_outside_prefixes()
        assert result == 5
        mock_mark.assert_called_once_with(["/models", "/input", "/output"])

    def test_mark_missing_raises_when_running(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """Calling mark_missing_outside_prefixes during a scan raises ScanInProgressError."""
        barrier = threading.Event()
        reached = threading.Event()

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                with pytest.raises(ScanInProgressError):
                    fresh_seeder.mark_missing_outside_prefixes()
            finally:
                # Release the parked scan even if the expected error was not raised.
                barrier.set()

    def test_mark_missing_returns_zero_when_dependencies_unavailable(
        self, fresh_seeder: _AssetSeeder
    ):
        """With dependencies_available() False the call returns 0."""
        with patch("app.assets.seeder.dependencies_available", return_value=False):
            result = fresh_seeder.mark_missing_outside_prefixes()
        assert result == 0

    def test_prune_first_flag_triggers_mark_missing_before_scan(
        self, fresh_seeder: _AssetSeeder
    ):
        """With prune_first=True, mark-missing is recorded before any root sync."""
        call_order = []

        def track_mark(prefixes):
            call_order.append("mark_missing")
            return 3

        def track_sync(root):
            call_order.append(f"sync_{root}")
            return set()

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.get_all_known_prefixes", return_value=["/models"]),
            patch("app.assets.seeder.mark_missing_outside_prefixes_safely", side_effect=track_mark),
            patch("app.assets.seeder.sync_root_safely", side_effect=track_sync),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",), prune_first=True)
            fresh_seeder.wait(timeout=5.0)
        assert call_order[0] == "mark_missing"
        assert "sync_models" in call_order
class TestSeederPhases:
    """Test phased scanning behavior."""

    @staticmethod
    def _count_phase_calls(seeder, launch):
        """Run ``launch()`` under a fully patched seeder and count phase entry points.

        Returns ``(fast_calls, enrich_calls)``: how often ``build_asset_specs``
        and ``get_unenriched_assets_for_roots`` were invoked before
        ``seeder.wait`` returned.
        """
        fast_hits = []
        enrich_hits = []

        def record_fast(*args, **kwargs):
            fast_hits.append(True)
            return ([], set(), 0)

        def record_enrich(*args, **kwargs):
            enrich_hits.append(True)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", side_effect=record_fast),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=record_enrich),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            launch()
            seeder.wait(timeout=5.0)
        return len(fast_hits), len(enrich_hits)

    def test_start_fast_only_runs_fast_phase(self, fresh_seeder: _AssetSeeder):
        """Verify start_fast only runs the fast phase."""
        fast_calls, enrich_calls = self._count_phase_calls(
            fresh_seeder,
            lambda: fresh_seeder.start_fast(roots=("models",)),
        )
        assert fast_calls == 1
        assert enrich_calls == 0

    def test_start_enrich_only_runs_enrich_phase(self, fresh_seeder: _AssetSeeder):
        """Verify start_enrich only runs the enrich phase."""
        fast_calls, enrich_calls = self._count_phase_calls(
            fresh_seeder,
            lambda: fresh_seeder.start_enrich(roots=("models",)),
        )
        assert fast_calls == 0
        assert enrich_calls == 1

    def test_full_scan_runs_both_phases(self, fresh_seeder: _AssetSeeder):
        """Verify full scan runs both fast and enrich phases."""
        fast_calls, enrich_calls = self._count_phase_calls(
            fresh_seeder,
            lambda: fresh_seeder.start(roots=("models",), phase=ScanPhase.FULL),
        )
        assert fast_calls == 1
        assert enrich_calls == 1
class TestSeederPauseResume:
    """Test pause/resume behavior."""

    @staticmethod
    def _blocking_collect():
        """Build a ``collect_paths_for_roots`` stand-in that parks its caller.

        Returns:
            ``(reached, release, side_effect)``. ``reached`` is set as soon as
            ``side_effect`` is entered; ``side_effect`` then blocks until
            ``release`` is set (or 5 seconds pass) and returns an empty list.
        """
        reached = threading.Event()
        release = threading.Event()

        def side_effect(*args):
            reached.set()
            release.wait(timeout=5.0)
            return []

        return reached, release, side_effect

    def test_pause_transitions_to_paused(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """pause() on a running scan returns True and the status becomes PAUSED."""
        reached, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                paused = fresh_seeder.pause()
                assert paused is True
                assert fresh_seeder.get_status().state == State.PAUSED
            finally:
                # Always release the scan thread, even if an assertion failed.
                release.set()

    def test_pause_when_idle_returns_false(self, fresh_seeder: _AssetSeeder):
        """pause() with no scan in flight returns False."""
        paused = fresh_seeder.pause()
        assert paused is False

    def test_resume_returns_to_running(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """resume() after pause() returns True and the status is RUNNING again."""
        reached, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                fresh_seeder.pause()
                assert fresh_seeder.get_status().state == State.PAUSED
                resumed = fresh_seeder.resume()
                assert resumed is True
                assert fresh_seeder.get_status().state == State.RUNNING
            finally:
                release.set()

    def test_resume_when_not_paused_returns_false(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """resume() on a scan that was never paused returns False."""
        reached, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                resumed = fresh_seeder.resume()
                assert resumed is False
            finally:
                release.set()

    def test_cancel_while_paused_works(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """cancel() is accepted while PAUSED and the seeder ends up IDLE."""
        reached_checkpoint, release, collect = self._blocking_collect()
        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached_checkpoint.wait(timeout=2.0)
                fresh_seeder.pause()
                assert fresh_seeder.get_status().state == State.PAUSED
                cancelled = fresh_seeder.cancel()
                assert cancelled is True
            finally:
                release.set()
            fresh_seeder.wait(timeout=5.0)
        assert fresh_seeder.get_status().state == State.IDLE
class TestSeederStopRestart:
    """Test stop and restart behavior."""

    def test_stop_is_alias_for_cancel(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """stop() on a running scan returns True and the status becomes CANCELLING."""
        barrier = threading.Event()
        reached = threading.Event()

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
                stopped = fresh_seeder.stop()
                assert stopped is True
                assert fresh_seeder.get_status().state == State.CANCELLING
            finally:
                # Always release the scan thread, even if an assertion failed.
                barrier.set()

    def test_restart_cancels_and_starts_new_scan(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """restart() during a scan returns True and collect runs a second time."""
        barrier = threading.Event()
        reached = threading.Event()
        start_count = 0

        def slow_collect(*args):
            nonlocal start_count
            start_count += 1
            if start_count == 1:
                # Only the first scan is parked; the restarted one runs freely.
                reached.set()
                barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            try:
                fresh_seeder.start(roots=("models",))
                assert reached.wait(timeout=2.0)
            finally:
                barrier.set()
            restarted = fresh_seeder.restart()
            assert restarted is True
            fresh_seeder.wait(timeout=5.0)
        assert start_count == 2

    def test_restart_preserves_previous_params(self, fresh_seeder: _AssetSeeder):
        """Verify restart uses previous params when not overridden."""
        collected_roots = []

        def track_collect(roots):
            collected_roots.append(roots)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("input", "output"))
            fresh_seeder.wait(timeout=5.0)
            fresh_seeder.restart()
            fresh_seeder.wait(timeout=5.0)
        assert len(collected_roots) == 2
        assert collected_roots[0] == ("input", "output")
        assert collected_roots[1] == ("input", "output")

    def test_restart_can_override_params(self, fresh_seeder: _AssetSeeder):
        """Verify restart can override previous params."""
        collected_roots = []

        def track_collect(roots):
            collected_roots.append(roots)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            fresh_seeder.start(roots=("models",))
            fresh_seeder.wait(timeout=5.0)
            fresh_seeder.restart(roots=("input",))
            fresh_seeder.wait(timeout=5.0)
        assert len(collected_roots) == 2
        assert collected_roots[0] == ("models",)
        assert collected_roots[1] == ("input",)
class TestEnqueueEnrichHandoff:
    """Test that the drain of _pending_enrich is atomic with start_enrich."""

    def test_pending_enrich_runs_after_scan_completes(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """A queued enrich request runs automatically when a scan finishes."""
        enrich_roots_seen: list[tuple] = []
        original_start = fresh_seeder.start

        def tracking_start(*args, **kwargs):
            # Record the roots of every ENRICH-phase start that was accepted.
            phase = kwargs.get("phase")
            roots = kwargs.get("roots", args[0] if args else None)
            result = original_start(*args, **kwargs)
            if phase == ScanPhase.ENRICH and result:
                enrich_roots_seen.append(roots)
            return result

        fresh_seeder.start = tracking_start

        # Start a fast scan, then enqueue an enrich while it's running
        barrier = threading.Event()
        reached = threading.Event()

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            try:
                fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
                assert reached.wait(timeout=2.0)
                queued = fresh_seeder.enqueue_enrich(
                    roots=("input",), compute_hashes=True
                )
                assert queued is False  # queued, not started immediately
            finally:
                # Release the fast scan even if an assertion above failed.
                barrier.set()
            # Wait for the original scan + the auto-started enrich scan
            deadline = time.monotonic() + 5.0
            while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
                time.sleep(0.05)
        assert enrich_roots_seen == [("input",)]

    def test_enqueue_enrich_during_drain_does_not_lose_work(
        self, fresh_seeder: _AssetSeeder, mock_dependencies
    ):
        """enqueue_enrich called concurrently with drain cannot drop work.

        Simulates the race: another thread calls enqueue_enrich right as the
        scan thread is draining _pending_enrich. The enqueue must either be
        picked up by the draining scan or successfully start its own scan.
        """
        barrier = threading.Event()
        reached = threading.Event()
        enrich_started = threading.Event()

        def slow_collect(*args):
            reached.set()
            barrier.wait(timeout=5.0)
            return []

        # Record the roots of every start_enrich call and flag the first
        # one that is accepted.
        real_start_enrich = fresh_seeder.start_enrich
        enrich_roots_seen: list[tuple] = []

        def tracking_start_enrich(**kwargs):
            enrich_roots_seen.append(kwargs.get("roots"))
            result = real_start_enrich(**kwargs)
            if result:
                enrich_started.set()
            return result

        fresh_seeder.start_enrich = tracking_start_enrich

        with patch(
            "app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
        ):
            try:
                # Start a scan
                fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
                assert reached.wait(timeout=2.0)
                # Queue an enrich while scan is running
                fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
            finally:
                # Let scan finish — drain will fire start_enrich atomically
                barrier.set()
            # Wait for drain to complete and the enrich scan to start
            assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
            assert ("output",) in enrich_roots_seen

    def test_concurrent_enqueue_during_drain_not_lost(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A second enqueue_enrich arriving while drain is in progress is not lost.

        Because the drain now holds _lock through the start_enrich call,
        a concurrent enqueue_enrich will block until start_enrich has
        transitioned state to RUNNING, then the enqueue will queue its
        payload as _pending_enrich for the *next* drain.
        """
        scan_barrier = threading.Event()
        scan_reached = threading.Event()
        enrich_barrier = threading.Event()
        enrich_reached = threading.Event()
        collect_call = 0

        def gated_collect(*args):
            nonlocal collect_call
            collect_call += 1
            if collect_call == 1:
                # First call: the initial fast scan
                scan_reached.set()
                scan_barrier.wait(timeout=5.0)
            return []

        enrich_call = 0

        def gated_get_unenriched(*args, **kwargs):
            nonlocal enrich_call
            enrich_call += 1
            if enrich_call == 1:
                # First enrich batch: signal and block
                enrich_reached.set()
                enrich_barrier.wait(timeout=5.0)
            return []

        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
        ):
            try:
                # 1. Start fast scan
                fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
                assert scan_reached.wait(timeout=2.0)
                # 2. Queue enrich while fast scan is running
                queued = fresh_seeder.enqueue_enrich(
                    roots=("input",), compute_hashes=False
                )
                assert queued is False
                # 3. Let the fast scan finish — drain will start the enrich scan
                scan_barrier.set()
                # 4. Wait until the drained enrich scan is running
                assert enrich_reached.wait(timeout=5.0)
                # 5. Now enqueue another enrich while the drained scan is running
                queued2 = fresh_seeder.enqueue_enrich(
                    roots=("output",), compute_hashes=True
                )
                assert queued2 is False  # should be queued, not started
                # Verify _pending_enrich was set (the second enqueue was captured)
                with fresh_seeder._lock:
                    assert fresh_seeder._pending_enrich is not None
                    assert "output" in fresh_seeder._pending_enrich["roots"]
            finally:
                # Open both gates unconditionally so no scan thread stays
                # parked after a failed assertion; on the success path this
                # is where the enrich scan is allowed to finish.
                scan_barrier.set()
                enrich_barrier.set()
            deadline = time.monotonic() + 5.0
            while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
                time.sleep(0.05)
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
    """Build a level-0 reference row whose fake file path is derived from *ref_id*."""
    fake_path = f"/fake/{ref_id}.bin"
    return UnenrichedReferenceRow(
        reference_id=ref_id,
        asset_id=asset_id,
        file_path=fake_path,
        enrichment_level=0,
    )
class TestEnrichPhaseDefensiveLogic:
    """Test skip_ids filtering and consecutive_empty termination."""

    @staticmethod
    def _run_enrich_scan(seeder, get_unenriched, enrich_batch):
        """Run one ENRICH-phase scan with the two given callables patched in.

        Every other seeder collaborator is replaced by an inert stub. The
        call returns after ``seeder.wait(timeout=5.0)``.
        """
        with (
            patch("app.assets.seeder.dependencies_available", return_value=True),
            patch("app.assets.seeder.sync_root_safely", return_value=set()),
            patch("app.assets.seeder.collect_paths_for_roots", return_value=[]),
            patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
            patch("app.assets.seeder.insert_asset_specs", return_value=0),
            patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=get_unenriched),
            patch("app.assets.seeder.enrich_assets_batch", side_effect=enrich_batch),
        ):
            seeder.start(roots=("models",), phase=ScanPhase.ENRICH)
            seeder.wait(timeout=5.0)

    def test_failed_refs_are_skipped_on_subsequent_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """References that fail enrichment are filtered out of future batches."""
        both_rows = [_make_row("r1"), _make_row("r2")]
        fetches = 0

        def fake_get_unenriched(*args, **kwargs):
            nonlocal fetches
            fetches += 1
            # The same two rows are offered twice, then nothing.
            return list(both_rows) if fetches <= 2 else []

        attempted: list[list[str]] = []

        def fake_enrich(rows, **kwargs):
            attempted.append([row.reference_id for row in rows])
            # r1 always fails, r2 succeeds
            failures = [row.reference_id for row in rows if row.reference_id == "r1"]
            return len(rows) - len(failures), failures

        self._run_enrich_scan(fresh_seeder, fake_get_unenriched, fake_enrich)

        first_batch, second_batch = attempted[0], attempted[1]
        # First batch: both refs attempted
        assert "r1" in first_batch
        assert "r2" in first_batch
        # Second batch: r1 filtered out
        assert "r1" not in second_batch
        assert "r2" in second_batch

    def test_stops_after_consecutive_empty_batches(
        self, fresh_seeder: _AssetSeeder,
    ):
        """Enrich phase stops once the only offered ref has already failed."""
        stuck_row = _make_row("r1")
        batch_count = 0

        def fake_get_unenriched(*args, **kwargs):
            nonlocal batch_count
            batch_count += 1
            # Always return the same row (simulating a permanently failing ref)
            return [stuck_row]

        def fake_enrich(rows, **kwargs):
            # Always fail — zero enriched, all failed
            return 0, [row.reference_id for row in rows]

        self._run_enrich_scan(fresh_seeder, fake_get_unenriched, fake_enrich)

        # Batch 1 attempts the row and it fails. Batch 2 is offered the same
        # row again; the expected total of two fetches corresponds to the
        # seeder filtering it out via its skip list and stopping on the
        # resulting empty batch, before any consecutive-empty limit is reached.
        assert batch_count == 2

    def test_consecutive_empty_counter_resets_on_success(
        self, fresh_seeder: _AssetSeeder,
    ):
        """A successful batch resets the consecutive empty counter."""
        call_count = 0
        failing_ids = ("r1", "r2", "r4", "r5")

        def fake_get_unenriched(*args, **kwargs):
            nonlocal call_count
            call_count += 1
            if call_count > 6:
                return []
            # Each of the first six fetches yields one brand-new row.
            return [_make_row(f"r{call_count}", f"a{call_count}")]

        def fake_enrich(rows, **kwargs):
            ref_id = rows[0].reference_id
            # Fail batches 1-2, succeed batch 3, fail batches 4-5, succeed batch 6
            if ref_id in failing_ids:
                return 0, [ref_id]
            return 1, []

        self._run_enrich_scan(fresh_seeder, fake_get_unenriched, fake_enrich)

        # All 6 batches should run + 1 final call returning empty
        assert call_count == 7
        final_status = fresh_seeder.get_status()
        assert final_status.state == State.IDLE