ComfyUI/app/assets/seeder.py
Luke Mino-Altherr 043a75acde
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Fix ruff linting issues
- Remove debug print statements
- Remove trailing whitespace on blank lines
- Remove unused pytest import

Amp-Thread-ID: https://ampcode.com/threads/T-019c3a8d-3b4f-75b4-8513-1c77914782f7
Co-authored-by: Amp <amp@ampcode.com>
2026-02-10 13:01:25 -08:00

434 lines
14 KiB
Python

"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable
from app.assets.scanner import (
RootType,
build_asset_specs,
collect_paths_for_roots,
get_all_known_prefixes,
get_prefixes_for_root,
insert_asset_specs,
mark_missing_outside_prefixes_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
class State(Enum):
    """Seeder state machine states.

    IDLE: no scan or maintenance operation is in progress.
    RUNNING: a scan (or a mark-missing pass) currently owns the seeder.
    CANCELLING: cancellation was requested and the scan has not yet exited.
    """

    IDLE = "IDLE"
    RUNNING = "RUNNING"
    CANCELLING = "CANCELLING"
@dataclass
class Progress:
    """Progress information for a scan operation.

    All counters start at zero and are plain integers.
    """

    # Number of asset specs that have been handed to the insert step so far.
    scanned: int = 0
    # Number of candidate paths collected for the scan.
    total: int = 0
    # Number of assets reported as created by the insert step.
    created: int = 0
    # Number of paths skipped while building specs (already known).
    skipped: int = 0
@dataclass
class ScanStatus:
    """Current status of the asset seeder."""

    # Current state-machine state of the seeder.
    state: State
    # Snapshot of the progress counters, or None when no progress exists yet.
    progress: Progress | None
    # Error messages collected so far; each instance gets its own list.
    errors: list[str] = field(default_factory=list)
# Signature of a progress listener: takes one Progress value and returns nothing.
ProgressCallback = Callable[[Progress], None]
class AssetSeeder:
    """Singleton class managing background asset scanning.

    Thread-safe singleton that spawns ephemeral daemon threads for scanning.
    Each scan creates a new thread that exits when complete.

    State transitions are guarded by ``self._lock``; cancellation is signalled
    to the worker through ``self._cancel_event``.
    """

    _instance: "AssetSeeder | None" = None
    _instance_lock = threading.Lock()

    def __new__(cls) -> "AssetSeeder":
        """Return the single shared instance, creating it on first use."""
        with cls._instance_lock:
            if cls._instance is None:
                cls._instance = super().__new__(cls)
                cls._instance._initialized = False
            return cls._instance

    def __init__(self) -> None:
        """Initialise instance state exactly once.

        ``__init__`` runs on every ``AssetSeeder()`` call, so the body is
        guarded by ``_initialized``. The guard is checked under the class lock
        so two concurrent constructors cannot both run the initialisation.
        """
        with self._instance_lock:
            if self._initialized:
                return
            self._lock = threading.Lock()
            self._state = State.IDLE
            self._progress: Progress | None = None
            self._errors: list[str] = []
            self._thread: threading.Thread | None = None
            self._cancel_event = threading.Event()
            self._roots: tuple[RootType, ...] = ()
            self._prune_first: bool = False
            self._compute_hashes: bool = False
            self._progress_callback: ProgressCallback | None = None
            # Flag is set last so a partially built instance is never marked ready.
            self._initialized = True

    def start(
        self,
        roots: tuple[RootType, ...] = ("models", "input", "output"),
        progress_callback: ProgressCallback | None = None,
        prune_first: bool = False,
        compute_hashes: bool = False,
    ) -> bool:
        """Start a background scan for the given roots.

        Args:
            roots: Tuple of root types to scan (models, input, output)
            progress_callback: Optional callback called with progress updates
            prune_first: If True, prune orphaned assets before scanning
            compute_hashes: If True, compute blake3 hashes for each file
                (slow for large files)

        Returns:
            True if scan was started, False if already running

        Raises:
            Exception: Whatever ``threading.Thread.start`` raises if the worker
                thread cannot be spawned; the seeder is returned to IDLE first.
        """
        with self._lock:
            if self._state != State.IDLE:
                return False
            self._state = State.RUNNING
            self._progress = Progress()
            self._errors = []
            self._roots = roots
            self._prune_first = prune_first
            self._compute_hashes = compute_hashes
            self._progress_callback = progress_callback
            self._cancel_event.clear()
            worker = threading.Thread(
                target=self._run_scan,
                name="AssetSeeder",
                daemon=True,
            )
            self._thread = worker
            try:
                worker.start()
            except Exception:
                # Without this rollback the seeder would stay RUNNING forever,
                # because only the worker's ``finally`` resets the state.
                self._thread = None
                self._state = State.IDLE
                raise
            return True

    def cancel(self) -> bool:
        """Request cancellation of the current scan.

        Returns:
            True if cancellation was requested, False if not running
        """
        with self._lock:
            if self._state != State.RUNNING:
                return False
            self._state = State.CANCELLING
            self._cancel_event.set()
            return True

    def wait(self, timeout: float | None = None) -> bool:
        """Wait for the current scan to complete.

        Args:
            timeout: Maximum seconds to wait, or None for no timeout

        Returns:
            True if the scan thread has finished or no scan thread exists,
            False if the timeout expired while the thread was still alive
        """
        with self._lock:
            thread = self._thread
        if thread is None:
            return True
        # Join outside the lock so the worker can still take it to finish.
        thread.join(timeout=timeout)
        return not thread.is_alive()

    def get_status(self) -> ScanStatus:
        """Get the current status and progress of the seeder.

        Returns:
            A ScanStatus holding copies of the progress counters and error
            list, so callers cannot mutate the seeder's internal state.
        """
        with self._lock:
            snapshot: Progress | None = None
            if self._progress:
                snapshot = Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )
            return ScanStatus(
                state=self._state,
                progress=snapshot,
                errors=list(self._errors),
            )

    def shutdown(self, timeout: float = 5.0) -> None:
        """Gracefully shutdown: cancel any running scan and wait for thread.

        Args:
            timeout: Maximum seconds to wait for thread to exit
        """
        self.cancel()
        self.wait(timeout=timeout)
        with self._lock:
            # Keep the reference if the thread outlived the timeout; dropping
            # it would make a later wait() report completion while the scan
            # is still running.
            if self._thread is not None and not self._thread.is_alive():
                self._thread = None

    def mark_missing_outside_prefixes(self) -> int:
        """Mark cache states as missing when outside all known root prefixes.

        This is a non-destructive soft-delete operation. Assets and their
        metadata are preserved, but cache states are flagged as missing.
        They can be restored if the file reappears in a future scan.

        This operation is decoupled from scanning to prevent partial scans
        from accidentally marking assets belonging to other roots.

        Should be called explicitly when cleanup is desired, typically after
        a full scan of all roots or during maintenance.

        Returns:
            Number of cache states marked as missing, or 0 if dependencies
            unavailable or a scan is currently running
        """
        with self._lock:
            if self._state != State.IDLE:
                logging.warning(
                    "Cannot mark missing assets while scan is running"
                )
                return 0
            # Claim the seeder so start() is refused while this runs.
            self._state = State.RUNNING
        try:
            if not dependencies_available():
                logging.warning(
                    "Database dependencies not available, skipping mark missing"
                )
                return 0
            all_prefixes = get_all_known_prefixes()
            marked = mark_missing_outside_prefixes_safely(all_prefixes)
            if marked > 0:
                logging.info("Marked %d cache states as missing", marked)
            return marked
        finally:
            with self._lock:
                self._state = State.IDLE

    def _is_cancelled(self) -> bool:
        """Check if cancellation has been requested."""
        return self._cancel_event.is_set()

    def _emit_event(self, event_type: str, data: dict) -> None:
        """Emit a WebSocket event if server is available.

        Best effort: any failure (including the server module being
        unimportable) is logged at debug level and otherwise ignored so that
        event delivery can never abort a scan.
        """
        try:
            from server import PromptServer

            if hasattr(PromptServer, "instance") and PromptServer.instance:
                PromptServer.instance.send_sync(event_type, data)
        except Exception:
            logging.debug(
                "Could not emit asset seeder event %s", event_type, exc_info=True
            )

    def _update_progress(
        self,
        scanned: int | None = None,
        total: int | None = None,
        created: int | None = None,
        skipped: int | None = None,
    ) -> None:
        """Update progress counters (thread-safe).

        Only the counters passed as non-None are changed. If a progress
        callback is registered it is invoked, outside the lock, with a copy
        of the updated counters.
        """
        callback: ProgressCallback | None = None
        progress: Progress | None = None
        with self._lock:
            if self._progress is None:
                return
            if scanned is not None:
                self._progress.scanned = scanned
            if total is not None:
                self._progress.total = total
            if created is not None:
                self._progress.created = created
            if skipped is not None:
                self._progress.skipped = skipped
            if self._progress_callback:
                callback = self._progress_callback
                progress = Progress(
                    scanned=self._progress.scanned,
                    total=self._progress.total,
                    created=self._progress.created,
                    skipped=self._progress.skipped,
                )
        if callback and progress:
            try:
                callback(progress)
            except Exception:
                # A faulty listener must not break the scan, but it should
                # not fail invisibly either.
                logging.warning(
                    "Asset seeder progress callback raised", exc_info=True
                )

    def _add_error(self, message: str) -> None:
        """Add an error message (thread-safe)."""
        with self._lock:
            self._errors.append(message)

    def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
        """Log the directories that will be scanned."""
        import folder_paths

        for root in roots:
            if root == "models":
                logging.info(
                    "Asset scan [models] directory: %s",
                    os.path.abspath(folder_paths.models_dir),
                )
            else:
                prefixes = get_prefixes_for_root(root)
                if prefixes:
                    logging.info("Asset scan [%s] directories: %s", root, prefixes)

    def _run_scan(self) -> None:
        """Main scan loop running in background thread.

        Phases: optional prune, per-root sync, path collection, spec building,
        then batched inserts. Cancellation is checked between phases and
        before every insert batch. The state is always reset to IDLE on exit.
        """
        t_start = time.perf_counter()
        roots = self._roots
        cancelled = False
        total_created = 0
        skipped_existing = 0
        total_paths = 0
        try:
            if not dependencies_available():
                self._add_error("Database dependencies not available")
                self._emit_event(
                    "assets.seed.error",
                    {"message": "Database dependencies not available"},
                )
                return

            if self._prune_first:
                all_prefixes = get_all_known_prefixes()
                marked = mark_missing_outside_prefixes_safely(all_prefixes)
                if marked > 0:
                    logging.info("Marked %d cache states as missing before scan", marked)
                if self._is_cancelled():
                    logging.info("Asset scan cancelled after pruning phase")
                    cancelled = True
                    return

            self._log_scan_config(roots)

            existing_paths: set[str] = set()
            for r in roots:
                if self._is_cancelled():
                    logging.info("Asset scan cancelled during sync phase")
                    cancelled = True
                    return
                existing_paths.update(sync_root_safely(r))

            if self._is_cancelled():
                logging.info("Asset scan cancelled after sync phase")
                cancelled = True
                return

            paths = collect_paths_for_roots(roots)
            total_paths = len(paths)
            self._update_progress(total=total_paths)
            self._emit_event(
                "assets.seed.started",
                {"roots": list(roots), "total": total_paths},
            )

            specs, tag_pool, skipped_existing = build_asset_specs(
                paths, existing_paths, compute_hashes=self._compute_hashes
            )
            self._update_progress(skipped=skipped_existing)

            if self._is_cancelled():
                logging.info("Asset scan cancelled after building specs")
                cancelled = True
                return

            batch_size = 500
            total_specs = len(specs)
            last_progress_time = time.perf_counter()
            # Minimum seconds between "assets.seed.progress" events.
            progress_interval = 1.0

            for i in range(0, total_specs, batch_size):
                if self._is_cancelled():
                    logging.info(
                        "Asset scan cancelled after %d/%d files (created=%d)",
                        i,
                        total_specs,
                        total_created,
                    )
                    cancelled = True
                    return

                batch = specs[i : i + batch_size]
                batch_tags = {t for spec in batch for t in spec["tags"]}
                try:
                    created = insert_asset_specs(batch, batch_tags)
                    total_created += created
                except Exception as e:
                    # A failed batch is recorded and the scan continues with
                    # the next batch.
                    self._add_error(f"Batch insert failed at offset {i}: {e}")
                    logging.exception("Batch insert failed at offset %d", i)

                scanned = i + len(batch)
                now = time.perf_counter()
                self._update_progress(scanned=scanned, created=total_created)
                if now - last_progress_time >= progress_interval:
                    self._emit_event(
                        "assets.seed.progress",
                        {
                            "scanned": scanned,
                            "total": total_specs,
                            "created": total_created,
                        },
                    )
                    last_progress_time = now

            self._update_progress(scanned=total_specs, created=total_created)
            elapsed = time.perf_counter() - t_start
            logging.info(
                "Asset scan(roots=%s) completed in %.3fs (created=%d, skipped=%d, total=%d)",
                roots,
                elapsed,
                total_created,
                skipped_existing,
                len(paths),
            )
            self._emit_event(
                "assets.seed.completed",
                {
                    "scanned": total_specs,
                    "total": total_paths,
                    "created": total_created,
                    "skipped": skipped_existing,
                    "elapsed": round(elapsed, 3),
                },
            )
        except Exception as e:
            self._add_error(f"Scan failed: {e}")
            logging.exception("Asset scan failed")
            self._emit_event("assets.seed.error", {"message": str(e)})
        finally:
            if cancelled:
                # Read the counter under the lock, like every other access.
                with self._lock:
                    scanned_so_far = self._progress.scanned if self._progress else 0
                self._emit_event(
                    "assets.seed.cancelled",
                    {
                        "scanned": scanned_so_far,
                        "total": total_paths,
                        "created": total_created,
                    },
                )
            with self._lock:
                self._state = State.IDLE
# Module-level AssetSeeder instance, created when this module is imported.
asset_seeder = AssetSeeder()