mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-05 17:27:42 +08:00
Add pause/resume/stop/restart controls to AssetSeeder
- Add PAUSED state to state machine - Add pause() method - blocks scan at next checkpoint - Add resume() method - unblocks paused scan - Add stop() method - alias for cancel() - Add restart() method - cancel + wait + start with same/overridden params - Add _check_pause_and_cancel() helper for checkpoint locations - Emit assets.seed.paused and assets.seed.resumed WebSocket events - Update get_object_info to use async seeder instead of blocking seed_assets - Scan all roots (models, input, output) on object_info, not just models Amp-Thread-ID: https://ampcode.com/threads/T-019c4f2b-5801-711c-8d47-bd1525808d77 Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
c7368205e3
commit
b89e5de40e
@ -30,6 +30,7 @@ class State(Enum):
|
|||||||
|
|
||||||
IDLE = "IDLE"
|
IDLE = "IDLE"
|
||||||
RUNNING = "RUNNING"
|
RUNNING = "RUNNING"
|
||||||
|
PAUSED = "PAUSED"
|
||||||
CANCELLING = "CANCELLING"
|
CANCELLING = "CANCELLING"
|
||||||
|
|
||||||
|
|
||||||
@ -90,6 +91,8 @@ class AssetSeeder:
|
|||||||
self._errors: list[str] = []
|
self._errors: list[str] = []
|
||||||
self._thread: threading.Thread | None = None
|
self._thread: threading.Thread | None = None
|
||||||
self._cancel_event = threading.Event()
|
self._cancel_event = threading.Event()
|
||||||
|
self._pause_event = threading.Event()
|
||||||
|
self._pause_event.set() # Start unpaused (set = running, clear = paused)
|
||||||
self._roots: tuple[RootType, ...] = ()
|
self._roots: tuple[RootType, ...] = ()
|
||||||
self._phase: ScanPhase = ScanPhase.FULL
|
self._phase: ScanPhase = ScanPhase.FULL
|
||||||
self._compute_hashes: bool = False
|
self._compute_hashes: bool = False
|
||||||
@ -127,6 +130,7 @@ class AssetSeeder:
|
|||||||
self._compute_hashes = compute_hashes
|
self._compute_hashes = compute_hashes
|
||||||
self._progress_callback = progress_callback
|
self._progress_callback = progress_callback
|
||||||
self._cancel_event.clear()
|
self._cancel_event.clear()
|
||||||
|
self._pause_event.set() # Ensure unpaused when starting
|
||||||
self._thread = threading.Thread(
|
self._thread = threading.Thread(
|
||||||
target=self._run_scan,
|
target=self._run_scan,
|
||||||
name="AssetSeeder",
|
name="AssetSeeder",
|
||||||
@ -187,15 +191,94 @@ class AssetSeeder:
|
|||||||
"""Request cancellation of the current scan.
|
"""Request cancellation of the current scan.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
True if cancellation was requested, False if not running
|
True if cancellation was requested, False if not running or paused
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self._state not in (State.RUNNING, State.PAUSED):
|
||||||
|
return False
|
||||||
|
self._state = State.CANCELLING
|
||||||
|
self._cancel_event.set()
|
||||||
|
self._pause_event.set() # Unblock if paused so thread can exit
|
||||||
|
return True
|
||||||
|
|
||||||
|
def stop(self) -> bool:
|
||||||
|
"""Stop the current scan (alias for cancel).
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if stop was requested, False if not running
|
||||||
|
"""
|
||||||
|
return self.cancel()
|
||||||
|
|
||||||
|
def pause(self) -> bool:
|
||||||
|
"""Pause the current scan.
|
||||||
|
|
||||||
|
The scan will complete its current batch before pausing.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if pause was requested, False if not running
|
||||||
"""
|
"""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if self._state != State.RUNNING:
|
if self._state != State.RUNNING:
|
||||||
return False
|
return False
|
||||||
self._state = State.CANCELLING
|
self._state = State.PAUSED
|
||||||
self._cancel_event.set()
|
self._pause_event.clear()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def resume(self) -> bool:
|
||||||
|
"""Resume a paused scan.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if resumed, False if not paused
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self._state != State.PAUSED:
|
||||||
|
return False
|
||||||
|
self._state = State.RUNNING
|
||||||
|
self._pause_event.set()
|
||||||
|
self._emit_event("assets.seed.resumed", {})
|
||||||
|
return True
|
||||||
|
|
||||||
|
def restart(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] | None = None,
|
||||||
|
phase: ScanPhase | None = None,
|
||||||
|
progress_callback: ProgressCallback | None = None,
|
||||||
|
prune_first: bool | None = None,
|
||||||
|
compute_hashes: bool | None = None,
|
||||||
|
timeout: float = 5.0,
|
||||||
|
) -> bool:
|
||||||
|
"""Cancel any running scan and start a new one.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Roots to scan (defaults to previous roots)
|
||||||
|
phase: Scan phase (defaults to previous phase)
|
||||||
|
progress_callback: Progress callback (defaults to previous)
|
||||||
|
prune_first: Prune before scan (defaults to previous)
|
||||||
|
compute_hashes: Compute hashes (defaults to previous)
|
||||||
|
timeout: Max seconds to wait for current scan to stop
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if new scan was started, False if failed to stop previous
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
prev_roots = self._roots
|
||||||
|
prev_phase = self._phase
|
||||||
|
prev_callback = self._progress_callback
|
||||||
|
prev_prune = getattr(self, "_prune_first", False)
|
||||||
|
prev_hashes = self._compute_hashes
|
||||||
|
|
||||||
|
self.cancel()
|
||||||
|
if not self.wait(timeout=timeout):
|
||||||
|
return False
|
||||||
|
|
||||||
|
return self.start(
|
||||||
|
roots=roots if roots is not None else prev_roots,
|
||||||
|
phase=phase if phase is not None else prev_phase,
|
||||||
|
progress_callback=progress_callback if progress_callback is not None else prev_callback,
|
||||||
|
prune_first=prune_first if prune_first is not None else prev_prune,
|
||||||
|
compute_hashes=compute_hashes if compute_hashes is not None else prev_hashes,
|
||||||
|
)
|
||||||
|
|
||||||
def wait(self, timeout: float | None = None) -> bool:
|
def wait(self, timeout: float | None = None) -> bool:
|
||||||
"""Wait for the current scan to complete.
|
"""Wait for the current scan to complete.
|
||||||
|
|
||||||
@ -284,6 +367,21 @@ class AssetSeeder:
|
|||||||
"""Check if cancellation has been requested."""
|
"""Check if cancellation has been requested."""
|
||||||
return self._cancel_event.is_set()
|
return self._cancel_event.is_set()
|
||||||
|
|
||||||
|
def _check_pause_and_cancel(self) -> bool:
|
||||||
|
"""Block while paused, then check if cancelled.
|
||||||
|
|
||||||
|
Call this at checkpoint locations in scan loops. It will:
|
||||||
|
1. Block indefinitely while paused (until resume or cancel)
|
||||||
|
2. Return True if cancelled, False to continue
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if scan should stop, False to continue
|
||||||
|
"""
|
||||||
|
if not self._pause_event.is_set():
|
||||||
|
self._emit_event("assets.seed.paused", {})
|
||||||
|
self._pause_event.wait() # Blocks if paused
|
||||||
|
return self._is_cancelled()
|
||||||
|
|
||||||
def _emit_event(self, event_type: str, data: dict) -> None:
|
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||||
"""Emit a WebSocket event if server is available."""
|
"""Emit a WebSocket event if server is available."""
|
||||||
try:
|
try:
|
||||||
@ -377,7 +475,7 @@ class AssetSeeder:
|
|||||||
if marked > 0:
|
if marked > 0:
|
||||||
logging.info("Marked %d cache states as missing before scan", marked)
|
logging.info("Marked %d cache states as missing before scan", marked)
|
||||||
|
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
logging.info("Asset scan cancelled after pruning phase")
|
logging.info("Asset scan cancelled after pruning phase")
|
||||||
cancelled = True
|
cancelled = True
|
||||||
return
|
return
|
||||||
@ -388,7 +486,7 @@ class AssetSeeder:
|
|||||||
if phase in (ScanPhase.FAST, ScanPhase.FULL):
|
if phase in (ScanPhase.FAST, ScanPhase.FULL):
|
||||||
total_created, skipped_existing, total_paths = self._run_fast_phase(roots)
|
total_created, skipped_existing, total_paths = self._run_fast_phase(roots)
|
||||||
|
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
cancelled = True
|
cancelled = True
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -404,7 +502,7 @@ class AssetSeeder:
|
|||||||
|
|
||||||
# Phase 2: Enrichment scan (metadata + hashes)
|
# Phase 2: Enrichment scan (metadata + hashes)
|
||||||
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
|
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
cancelled = True
|
cancelled = True
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -469,11 +567,11 @@ class AssetSeeder:
|
|||||||
|
|
||||||
existing_paths: set[str] = set()
|
existing_paths: set[str] = set()
|
||||||
for r in roots:
|
for r in roots:
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
return total_created, skipped_existing, 0
|
return total_created, skipped_existing, 0
|
||||||
existing_paths.update(sync_root_safely(r))
|
existing_paths.update(sync_root_safely(r))
|
||||||
|
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
return total_created, skipped_existing, 0
|
return total_created, skipped_existing, 0
|
||||||
|
|
||||||
paths = collect_paths_for_roots(roots)
|
paths = collect_paths_for_roots(roots)
|
||||||
@ -489,7 +587,7 @@ class AssetSeeder:
|
|||||||
specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths)
|
specs, tag_pool, skipped_existing = build_stub_specs(paths, existing_paths)
|
||||||
self._update_progress(skipped=skipped_existing)
|
self._update_progress(skipped=skipped_existing)
|
||||||
|
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
return total_created, skipped_existing, total_paths
|
return total_created, skipped_existing, total_paths
|
||||||
|
|
||||||
batch_size = 500
|
batch_size = 500
|
||||||
@ -497,7 +595,7 @@ class AssetSeeder:
|
|||||||
progress_interval = 1.0
|
progress_interval = 1.0
|
||||||
|
|
||||||
for i in range(0, len(specs), batch_size):
|
for i in range(0, len(specs), batch_size):
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
logging.info(
|
logging.info(
|
||||||
"Fast scan cancelled after %d/%d files (created=%d)",
|
"Fast scan cancelled after %d/%d files (created=%d)",
|
||||||
i,
|
i,
|
||||||
@ -554,7 +652,7 @@ class AssetSeeder:
|
|||||||
)
|
)
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if self._is_cancelled():
|
if self._check_pause_and_cancel():
|
||||||
logging.info("Enrich scan cancelled after %d assets", total_enriched)
|
logging.info("Enrich scan cancelled after %d assets", total_enriched)
|
||||||
break
|
break
|
||||||
|
|
||||||
|
|||||||
@ -33,7 +33,7 @@ import node_helpers
|
|||||||
from comfyui_version import __version__
|
from comfyui_version import __version__
|
||||||
from app.frontend_management import FrontendManager, parse_version
|
from app.frontend_management import FrontendManager, parse_version
|
||||||
from comfy_api.internal import _ComfyNodeInternal
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
from app.assets.scanner import seed_assets
|
from app.assets.seeder import asset_seeder
|
||||||
from app.assets.api.routes import register_assets_system
|
from app.assets.api.routes import register_assets_system
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
@ -697,10 +697,7 @@ class PromptServer():
|
|||||||
|
|
||||||
@routes.get("/object_info")
|
@routes.get("/object_info")
|
||||||
async def get_object_info(request):
|
async def get_object_info(request):
|
||||||
try:
|
asset_seeder.start(roots=("models", "input", "output"))
|
||||||
seed_assets(["models"])
|
|
||||||
except Exception as e:
|
|
||||||
logging.error(f"Failed to seed assets: {e}")
|
|
||||||
with folder_paths.cache_helper:
|
with folder_paths.cache_helper:
|
||||||
out = {}
|
out = {}
|
||||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||||
|
|||||||
@ -523,3 +523,261 @@ class TestSeederPhases:
|
|||||||
|
|
||||||
assert len(fast_called) == 1
|
assert len(fast_called) == 1
|
||||||
assert len(enrich_called) == 1
|
assert len(enrich_called) == 1
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederPauseResume:
|
||||||
|
"""Test pause/resume behavior."""
|
||||||
|
|
||||||
|
def test_pause_transitions_to_paused(
|
||||||
|
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
barrier = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
paused = fresh_seeder.pause()
|
||||||
|
assert paused is True
|
||||||
|
assert fresh_seeder.get_status().state == State.PAUSED
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
def test_pause_when_idle_returns_false(self, fresh_seeder: AssetSeeder):
|
||||||
|
paused = fresh_seeder.pause()
|
||||||
|
assert paused is False
|
||||||
|
|
||||||
|
def test_resume_returns_to_running(
|
||||||
|
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
barrier = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
fresh_seeder.pause()
|
||||||
|
assert fresh_seeder.get_status().state == State.PAUSED
|
||||||
|
|
||||||
|
resumed = fresh_seeder.resume()
|
||||||
|
assert resumed is True
|
||||||
|
assert fresh_seeder.get_status().state == State.RUNNING
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
def test_resume_when_not_paused_returns_false(
|
||||||
|
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
barrier = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
resumed = fresh_seeder.resume()
|
||||||
|
assert resumed is False
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
def test_cancel_while_paused_works(
|
||||||
|
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached_checkpoint = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached_checkpoint.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
reached_checkpoint.wait(timeout=1.0)
|
||||||
|
|
||||||
|
fresh_seeder.pause()
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
cancelled = fresh_seeder.cancel()
|
||||||
|
assert cancelled is True
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
assert fresh_seeder.get_status().state == State.IDLE
|
||||||
|
|
||||||
|
def test_pause_blocks_scan_until_resume(self, fresh_seeder: AssetSeeder):
|
||||||
|
"""Verify scan blocks at checkpoint while paused."""
|
||||||
|
batch_count = 0
|
||||||
|
pause_detected = threading.Event()
|
||||||
|
resume_signal = threading.Event()
|
||||||
|
|
||||||
|
def counting_insert(specs, tags):
|
||||||
|
nonlocal batch_count
|
||||||
|
batch_count += 1
|
||||||
|
if batch_count == 1:
|
||||||
|
pause_detected.set()
|
||||||
|
resume_signal.wait(timeout=5.0)
|
||||||
|
return len(specs)
|
||||||
|
|
||||||
|
paths = [f"/path/file{i}.safetensors" for i in range(1000)]
|
||||||
|
specs = [
|
||||||
|
{
|
||||||
|
"abs_path": p,
|
||||||
|
"size_bytes": 100,
|
||||||
|
"mtime_ns": 0,
|
||||||
|
"info_name": f"file{i}",
|
||||||
|
"tags": [],
|
||||||
|
"fname": f"file{i}",
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": None,
|
||||||
|
}
|
||||||
|
for i, p in enumerate(paths)
|
||||||
|
]
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", return_value=paths),
|
||||||
|
patch("app.assets.seeder.build_stub_specs", return_value=(specs, set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", side_effect=counting_insert),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
pause_detected.wait(timeout=2.0)
|
||||||
|
|
||||||
|
fresh_seeder.pause()
|
||||||
|
count_at_pause = batch_count
|
||||||
|
time.sleep(0.1)
|
||||||
|
assert batch_count == count_at_pause
|
||||||
|
|
||||||
|
fresh_seeder.resume()
|
||||||
|
resume_signal.set()
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
|
||||||
|
assert batch_count > count_at_pause
|
||||||
|
|
||||||
|
|
||||||
|
class TestSeederStopRestart:
|
||||||
|
"""Test stop and restart behavior."""
|
||||||
|
|
||||||
|
def test_stop_is_alias_for_cancel(
|
||||||
|
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
barrier = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
stopped = fresh_seeder.stop()
|
||||||
|
assert stopped is True
|
||||||
|
assert fresh_seeder.get_status().state == State.CANCELLING
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
def test_restart_cancels_and_starts_new_scan(
|
||||||
|
self, fresh_seeder: AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
barrier = threading.Event()
|
||||||
|
start_count = 0
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
nonlocal start_count
|
||||||
|
start_count += 1
|
||||||
|
if start_count == 1:
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
restarted = fresh_seeder.restart()
|
||||||
|
assert restarted is True
|
||||||
|
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
assert start_count == 2
|
||||||
|
|
||||||
|
def test_restart_preserves_previous_params(self, fresh_seeder: AssetSeeder):
|
||||||
|
"""Verify restart uses previous params when not overridden."""
|
||||||
|
collected_roots = []
|
||||||
|
|
||||||
|
def track_collect(roots):
|
||||||
|
collected_roots.append(roots)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
|
||||||
|
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("input", "output"))
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
|
||||||
|
fresh_seeder.restart()
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
|
||||||
|
assert len(collected_roots) == 2
|
||||||
|
assert collected_roots[0] == ("input", "output")
|
||||||
|
assert collected_roots[1] == ("input", "output")
|
||||||
|
|
||||||
|
def test_restart_can_override_params(self, fresh_seeder: AssetSeeder):
|
||||||
|
"""Verify restart can override previous params."""
|
||||||
|
collected_roots = []
|
||||||
|
|
||||||
|
def track_collect(roots):
|
||||||
|
collected_roots.append(roots)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", side_effect=track_collect),
|
||||||
|
patch("app.assets.seeder.build_stub_specs", return_value=([], set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", return_value=[]),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",))
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
|
||||||
|
fresh_seeder.restart(roots=("input",))
|
||||||
|
fresh_seeder.wait(timeout=5.0)
|
||||||
|
|
||||||
|
assert len(collected_roots) == 2
|
||||||
|
assert collected_roots[0] == ("models",)
|
||||||
|
assert collected_roots[1] == ("input",)
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user