mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-29 21:13:33 +08:00
feat(assets): register output files as assets after prompt execution (#12812)
This commit is contained in:
parent
5ebb0c2e0b
commit
7d5534d8e5
@ -1,6 +1,7 @@
|
|||||||
from app.assets.database.queries.asset import (
|
from app.assets.database.queries.asset import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
create_stub_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
get_existing_asset_ids,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
@ -12,6 +13,7 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
UnenrichedReferenceRow,
|
UnenrichedReferenceRow,
|
||||||
bulk_insert_references_ignore_conflicts,
|
bulk_insert_references_ignore_conflicts,
|
||||||
bulk_update_enrichment_level,
|
bulk_update_enrichment_level,
|
||||||
|
count_active_siblings,
|
||||||
bulk_update_is_missing,
|
bulk_update_is_missing,
|
||||||
bulk_update_needs_verify,
|
bulk_update_needs_verify,
|
||||||
convert_metadata_to_rows,
|
convert_metadata_to_rows,
|
||||||
@ -80,6 +82,8 @@ __all__ = [
|
|||||||
"bulk_insert_references_ignore_conflicts",
|
"bulk_insert_references_ignore_conflicts",
|
||||||
"bulk_insert_tags_and_meta",
|
"bulk_insert_tags_and_meta",
|
||||||
"bulk_update_enrichment_level",
|
"bulk_update_enrichment_level",
|
||||||
|
"count_active_siblings",
|
||||||
|
"create_stub_asset",
|
||||||
"bulk_update_is_missing",
|
"bulk_update_is_missing",
|
||||||
"bulk_update_needs_verify",
|
"bulk_update_needs_verify",
|
||||||
"convert_metadata_to_rows",
|
"convert_metadata_to_rows",
|
||||||
|
|||||||
@ -78,6 +78,18 @@ def upsert_asset(
|
|||||||
return asset, created, updated
|
return asset, created, updated
|
||||||
|
|
||||||
|
|
||||||
|
def create_stub_asset(
|
||||||
|
session: Session,
|
||||||
|
size_bytes: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> Asset:
|
||||||
|
"""Create a new asset with no hash (stub for later enrichment)."""
|
||||||
|
asset = Asset(size_bytes=size_bytes, mime_type=mime_type, hash=None)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_assets(
|
def bulk_insert_assets(
|
||||||
session: Session,
|
session: Session,
|
||||||
rows: list[dict],
|
rows: list[dict],
|
||||||
|
|||||||
@ -114,6 +114,23 @@ def get_reference_by_file_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def count_active_siblings(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
exclude_reference_id: str,
|
||||||
|
) -> int:
|
||||||
|
"""Count active (non-deleted) references to an asset, excluding one reference."""
|
||||||
|
return (
|
||||||
|
session.query(AssetReference)
|
||||||
|
.filter(
|
||||||
|
AssetReference.asset_id == asset_id,
|
||||||
|
AssetReference.id != exclude_reference_id,
|
||||||
|
AssetReference.deleted_at.is_(None),
|
||||||
|
)
|
||||||
|
.count()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def reference_exists_for_asset_id(
|
def reference_exists_for_asset_id(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
delete_references_by_ids,
|
delete_references_by_ids,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
|
get_reference_by_id,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
@ -338,6 +339,7 @@ def build_asset_specs(
|
|||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
"hash": asset_hash,
|
"hash": asset_hash,
|
||||||
"mime_type": mime_type,
|
"mime_type": mime_type,
|
||||||
|
"job_id": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tag_pool.update(tags)
|
tag_pool.update(tags)
|
||||||
@ -426,6 +428,7 @@ def enrich_asset(
|
|||||||
except OSError:
|
except OSError:
|
||||||
return new_level
|
return new_level
|
||||||
|
|
||||||
|
initial_mtime_ns = get_mtime_ns(stat_p)
|
||||||
rel_fname = compute_relative_filename(file_path)
|
rel_fname = compute_relative_filename(file_path)
|
||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
metadata = None
|
metadata = None
|
||||||
@ -489,6 +492,18 @@ def enrich_asset(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
|
# Optimistic guard: if the reference's mtime_ns changed since we
|
||||||
|
# started (e.g. ingest_existing_file updated it), our results are
|
||||||
|
# stale — discard them to avoid overwriting fresh registration data.
|
||||||
|
ref = get_reference_by_id(session, reference_id)
|
||||||
|
if ref is None or ref.mtime_ns != initial_mtime_ns:
|
||||||
|
session.rollback()
|
||||||
|
logging.info(
|
||||||
|
"Ref %s mtime changed during enrichment, discarding stale result",
|
||||||
|
reference_id,
|
||||||
|
)
|
||||||
|
return ENRICHMENT_STUB
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
system_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|||||||
@ -77,7 +77,9 @@ class _AssetSeeder:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self) -> None:
|
def __init__(self) -> None:
|
||||||
self._lock = threading.Lock()
|
# RLock is required because _run_scan() drains pending work while
|
||||||
|
# holding _lock and re-enters start() which also acquires _lock.
|
||||||
|
self._lock = threading.RLock()
|
||||||
self._state = State.IDLE
|
self._state = State.IDLE
|
||||||
self._progress: Progress | None = None
|
self._progress: Progress | None = None
|
||||||
self._last_progress: Progress | None = None
|
self._last_progress: Progress | None = None
|
||||||
@ -92,6 +94,7 @@ class _AssetSeeder:
|
|||||||
self._prune_first: bool = False
|
self._prune_first: bool = False
|
||||||
self._progress_callback: ProgressCallback | None = None
|
self._progress_callback: ProgressCallback | None = None
|
||||||
self._disabled: bool = False
|
self._disabled: bool = False
|
||||||
|
self._pending_enrich: dict | None = None
|
||||||
|
|
||||||
def disable(self) -> None:
|
def disable(self) -> None:
|
||||||
"""Disable the asset seeder, preventing any scans from starting."""
|
"""Disable the asset seeder, preventing any scans from starting."""
|
||||||
@ -196,6 +199,42 @@ class _AssetSeeder:
|
|||||||
compute_hashes=compute_hashes,
|
compute_hashes=compute_hashes,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def enqueue_enrich(
|
||||||
|
self,
|
||||||
|
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||||
|
compute_hashes: bool = False,
|
||||||
|
) -> bool:
|
||||||
|
"""Start an enrichment scan now, or queue it for after the current scan.
|
||||||
|
|
||||||
|
If the seeder is idle, starts immediately. Otherwise, the enrich
|
||||||
|
request is stored and will run automatically when the current scan
|
||||||
|
finishes.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
roots: Tuple of root types to scan
|
||||||
|
compute_hashes: If True, compute blake3 hashes
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True if started immediately, False if queued for later
|
||||||
|
"""
|
||||||
|
with self._lock:
|
||||||
|
if self.start_enrich(roots=roots, compute_hashes=compute_hashes):
|
||||||
|
return True
|
||||||
|
if self._pending_enrich is not None:
|
||||||
|
existing_roots = set(self._pending_enrich["roots"])
|
||||||
|
existing_roots.update(roots)
|
||||||
|
self._pending_enrich["roots"] = tuple(existing_roots)
|
||||||
|
self._pending_enrich["compute_hashes"] = (
|
||||||
|
self._pending_enrich["compute_hashes"] or compute_hashes
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self._pending_enrich = {
|
||||||
|
"roots": roots,
|
||||||
|
"compute_hashes": compute_hashes,
|
||||||
|
}
|
||||||
|
logging.info("Enrich scan queued (roots=%s)", self._pending_enrich["roots"])
|
||||||
|
return False
|
||||||
|
|
||||||
def cancel(self) -> bool:
|
def cancel(self) -> bool:
|
||||||
"""Request cancellation of the current scan.
|
"""Request cancellation of the current scan.
|
||||||
|
|
||||||
@ -381,9 +420,13 @@ class _AssetSeeder:
|
|||||||
return marked
|
return marked
|
||||||
finally:
|
finally:
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
|
||||||
self._progress = None
|
def _reset_to_idle(self) -> None:
|
||||||
|
"""Reset state to IDLE, preserving last progress. Caller must hold _lock."""
|
||||||
|
self._last_progress = self._progress
|
||||||
|
self._state = State.IDLE
|
||||||
|
self._progress = None
|
||||||
|
|
||||||
def _is_cancelled(self) -> bool:
|
def _is_cancelled(self) -> bool:
|
||||||
"""Check if cancellation has been requested."""
|
"""Check if cancellation has been requested."""
|
||||||
@ -594,9 +637,18 @@ class _AssetSeeder:
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
with self._lock:
|
with self._lock:
|
||||||
self._last_progress = self._progress
|
self._reset_to_idle()
|
||||||
self._state = State.IDLE
|
pending = self._pending_enrich
|
||||||
self._progress = None
|
if pending is not None:
|
||||||
|
self._pending_enrich = None
|
||||||
|
if not self.start_enrich(
|
||||||
|
roots=pending["roots"],
|
||||||
|
compute_hashes=pending["compute_hashes"],
|
||||||
|
):
|
||||||
|
logging.warning(
|
||||||
|
"Pending enrich scan could not start (roots=%s)",
|
||||||
|
pending["roots"],
|
||||||
|
)
|
||||||
|
|
||||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||||
"""Run phase 1: fast scan to create stub records.
|
"""Run phase 1: fast scan to create stub records.
|
||||||
|
|||||||
@ -23,6 +23,8 @@ from app.assets.services.ingest import (
|
|||||||
DependencyMissingError,
|
DependencyMissingError,
|
||||||
HashMismatchError,
|
HashMismatchError,
|
||||||
create_from_hash,
|
create_from_hash,
|
||||||
|
ingest_existing_file,
|
||||||
|
register_output_files,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
@ -72,6 +74,8 @@ __all__ = [
|
|||||||
"delete_asset_reference",
|
"delete_asset_reference",
|
||||||
"get_asset_by_hash",
|
"get_asset_by_hash",
|
||||||
"get_asset_detail",
|
"get_asset_detail",
|
||||||
|
"ingest_existing_file",
|
||||||
|
"register_output_files",
|
||||||
"get_mtime_ns",
|
"get_mtime_ns",
|
||||||
"get_size_and_mtime_ns",
|
"get_size_and_mtime_ns",
|
||||||
"list_assets_page",
|
"list_assets_page",
|
||||||
|
|||||||
@ -37,6 +37,7 @@ class SeedAssetSpec(TypedDict):
|
|||||||
metadata: ExtractedMetadata | None
|
metadata: ExtractedMetadata | None
|
||||||
hash: str | None
|
hash: str | None
|
||||||
mime_type: str | None
|
mime_type: str | None
|
||||||
|
job_id: str | None
|
||||||
|
|
||||||
|
|
||||||
class AssetRow(TypedDict):
|
class AssetRow(TypedDict):
|
||||||
@ -60,6 +61,7 @@ class ReferenceRow(TypedDict):
|
|||||||
name: str
|
name: str
|
||||||
preview_id: str | None
|
preview_id: str | None
|
||||||
user_metadata: dict[str, Any] | None
|
user_metadata: dict[str, Any] | None
|
||||||
|
job_id: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
last_access_time: datetime
|
last_access_time: datetime
|
||||||
@ -167,6 +169,7 @@ def batch_insert_seed_assets(
|
|||||||
"name": spec["info_name"],
|
"name": spec["info_name"],
|
||||||
"preview_id": None,
|
"preview_id": None,
|
||||||
"user_metadata": user_metadata,
|
"user_metadata": user_metadata,
|
||||||
|
"job_id": spec.get("job_id"),
|
||||||
"created_at": current_time,
|
"created_at": current_time,
|
||||||
"updated_at": current_time,
|
"updated_at": current_time,
|
||||||
"last_access_time": current_time,
|
"last_access_time": current_time,
|
||||||
|
|||||||
@ -9,6 +9,9 @@ from sqlalchemy.orm import Session
|
|||||||
import app.assets.services.hashing as hashing
|
import app.assets.services.hashing as hashing
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
|
count_active_siblings,
|
||||||
|
create_stub_asset,
|
||||||
|
ensure_tags_exist,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
@ -23,7 +26,8 @@ from app.assets.database.queries import (
|
|||||||
upsert_reference,
|
upsert_reference,
|
||||||
validate_tags_exist,
|
validate_tags_exist,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import normalize_tags
|
from app.assets.helpers import get_utc_now, normalize_tags
|
||||||
|
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
@ -130,6 +134,102 @@ def _ingest_file_from_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_output_files(
|
||||||
|
file_paths: Sequence[str],
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> int:
|
||||||
|
"""Register a batch of output file paths as assets.
|
||||||
|
|
||||||
|
Returns the number of files successfully registered.
|
||||||
|
"""
|
||||||
|
registered = 0
|
||||||
|
for abs_path in file_paths:
|
||||||
|
if not os.path.isfile(abs_path):
|
||||||
|
continue
|
||||||
|
try:
|
||||||
|
if ingest_existing_file(
|
||||||
|
abs_path, user_metadata=user_metadata, job_id=job_id
|
||||||
|
):
|
||||||
|
registered += 1
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to register output: %s", abs_path)
|
||||||
|
return registered
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_existing_file(
|
||||||
|
abs_path: str,
|
||||||
|
user_metadata: UserMetadata = None,
|
||||||
|
extra_tags: Sequence[str] = (),
|
||||||
|
owner_id: str = "",
|
||||||
|
job_id: str | None = None,
|
||||||
|
) -> bool:
|
||||||
|
"""Register an existing on-disk file as an asset stub.
|
||||||
|
|
||||||
|
If a reference already exists for this path, updates mtime_ns, job_id,
|
||||||
|
size_bytes, and resets enrichment so the enricher will re-hash it.
|
||||||
|
|
||||||
|
For brand-new paths, inserts a stub record (hash=NULL) for immediate
|
||||||
|
UX visibility.
|
||||||
|
|
||||||
|
Returns True if a row was inserted or updated, False otherwise.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
|
mime_type = mimetypes.guess_type(abs_path, strict=False)[0]
|
||||||
|
name, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||||
|
tags = list(dict.fromkeys(path_tags + list(extra_tags)))
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
existing_ref = get_reference_by_file_path(session, locator)
|
||||||
|
if existing_ref is not None:
|
||||||
|
now = get_utc_now()
|
||||||
|
existing_ref.mtime_ns = mtime_ns
|
||||||
|
existing_ref.job_id = job_id
|
||||||
|
existing_ref.is_missing = False
|
||||||
|
existing_ref.deleted_at = None
|
||||||
|
existing_ref.updated_at = now
|
||||||
|
existing_ref.enrichment_level = 0
|
||||||
|
|
||||||
|
asset = existing_ref.asset
|
||||||
|
if asset:
|
||||||
|
# If other refs share this asset, detach to a new stub
|
||||||
|
# instead of mutating the shared row.
|
||||||
|
siblings = count_active_siblings(session, asset.id, existing_ref.id)
|
||||||
|
if siblings > 0:
|
||||||
|
new_asset = create_stub_asset(
|
||||||
|
session,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mime_type=mime_type or asset.mime_type,
|
||||||
|
)
|
||||||
|
existing_ref.asset_id = new_asset.id
|
||||||
|
else:
|
||||||
|
asset.hash = None
|
||||||
|
asset.size_bytes = size_bytes
|
||||||
|
if mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
spec = {
|
||||||
|
"abs_path": abs_path,
|
||||||
|
"size_bytes": size_bytes,
|
||||||
|
"mtime_ns": mtime_ns,
|
||||||
|
"info_name": name,
|
||||||
|
"tags": tags,
|
||||||
|
"fname": os.path.basename(abs_path),
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"job_id": job_id,
|
||||||
|
}
|
||||||
|
if tags:
|
||||||
|
ensure_tags_exist(session, tags)
|
||||||
|
result = batch_insert_seed_assets(session, [spec], owner_id=owner_id)
|
||||||
|
session.commit()
|
||||||
|
return result.won_paths > 0
|
||||||
|
|
||||||
|
|
||||||
def _register_existing_asset(
|
def _register_existing_asset(
|
||||||
asset_hash: str,
|
asset_hash: str,
|
||||||
name: str,
|
name: str,
|
||||||
|
|||||||
43
main.py
43
main.py
@ -9,6 +9,8 @@ import folder_paths
|
|||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
from app.assets.seeder import asset_seeder
|
||||||
|
from app.assets.services import register_output_files
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
from utils.mime_types import init_mime_types
|
from utils.mime_types import init_mime_types
|
||||||
@ -192,7 +194,6 @@ if 'torch' in sys.modules:
|
|||||||
|
|
||||||
|
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
from app.assets.seeder import asset_seeder
|
|
||||||
|
|
||||||
import execution
|
import execution
|
||||||
import server
|
import server
|
||||||
@ -240,6 +241,38 @@ def cuda_malloc_warning():
|
|||||||
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
logging.warning("\nWARNING: this card most likely does not support cuda-malloc, if you get \"CUDA error\" please run ComfyUI with: --disable-cuda-malloc\n")
|
||||||
|
|
||||||
|
|
||||||
|
def _collect_output_absolute_paths(history_result: dict) -> list[str]:
|
||||||
|
"""Extract absolute file paths for output items from a history result."""
|
||||||
|
paths: list[str] = []
|
||||||
|
seen: set[str] = set()
|
||||||
|
for node_output in history_result.get("outputs", {}).values():
|
||||||
|
for items in node_output.values():
|
||||||
|
if not isinstance(items, list):
|
||||||
|
continue
|
||||||
|
for item in items:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
item_type = item.get("type")
|
||||||
|
if item_type not in ("output", "temp"):
|
||||||
|
continue
|
||||||
|
base_dir = folder_paths.get_directory_by_type(item_type)
|
||||||
|
if base_dir is None:
|
||||||
|
continue
|
||||||
|
base_dir = os.path.abspath(base_dir)
|
||||||
|
filename = item.get("filename")
|
||||||
|
if not filename:
|
||||||
|
continue
|
||||||
|
abs_path = os.path.abspath(
|
||||||
|
os.path.join(base_dir, item.get("subfolder", ""), filename)
|
||||||
|
)
|
||||||
|
if not abs_path.startswith(base_dir + os.sep) and abs_path != base_dir:
|
||||||
|
continue
|
||||||
|
if abs_path not in seen:
|
||||||
|
seen.add(abs_path)
|
||||||
|
paths.append(abs_path)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
def prompt_worker(q, server_instance):
|
def prompt_worker(q, server_instance):
|
||||||
current_time: float = 0.0
|
current_time: float = 0.0
|
||||||
cache_type = execution.CacheType.CLASSIC
|
cache_type = execution.CacheType.CLASSIC
|
||||||
@ -274,6 +307,7 @@ def prompt_worker(q, server_instance):
|
|||||||
|
|
||||||
asset_seeder.pause()
|
asset_seeder.pause()
|
||||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||||
|
|
||||||
need_gc = True
|
need_gc = True
|
||||||
|
|
||||||
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
remove_sensitive = lambda prompt: prompt[:5] + prompt[6:]
|
||||||
@ -296,6 +330,10 @@ def prompt_worker(q, server_instance):
|
|||||||
else:
|
else:
|
||||||
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
logging.info("Prompt executed in {:.2f} seconds".format(execution_time))
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
paths = _collect_output_absolute_paths(e.history_result)
|
||||||
|
register_output_files(paths, job_id=prompt_id)
|
||||||
|
|
||||||
flags = q.get_flags()
|
flags = q.get_flags()
|
||||||
free_memory = flags.get("free_memory", False)
|
free_memory = flags.get("free_memory", False)
|
||||||
|
|
||||||
@ -317,6 +355,9 @@ def prompt_worker(q, server_instance):
|
|||||||
last_gc_collect = current_time
|
last_gc_collect = current_time
|
||||||
need_gc = False
|
need_gc = False
|
||||||
hook_breaker_ac10a0.restore_functions()
|
hook_breaker_ac10a0.restore_functions()
|
||||||
|
|
||||||
|
if not asset_seeder.is_disabled():
|
||||||
|
asset_seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
asset_seeder.resume()
|
asset_seeder.resume()
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -3,7 +3,7 @@ from pathlib import Path
|
|||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy import create_engine
|
from sqlalchemy import create_engine, event
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Base
|
from app.assets.database.models import Base
|
||||||
@ -23,6 +23,21 @@ def db_engine():
|
|||||||
return engine
|
return engine
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def db_engine_fk():
|
||||||
|
"""In-memory SQLite engine with foreign key enforcement enabled."""
|
||||||
|
engine = create_engine("sqlite:///:memory:")
|
||||||
|
|
||||||
|
@event.listens_for(engine, "connect")
|
||||||
|
def _set_pragma(dbapi_connection, connection_record):
|
||||||
|
cursor = dbapi_connection.cursor()
|
||||||
|
cursor.execute("PRAGMA foreign_keys=ON")
|
||||||
|
cursor.close()
|
||||||
|
|
||||||
|
Base.metadata.create_all(engine)
|
||||||
|
return engine
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def session(db_engine):
|
def session(db_engine):
|
||||||
"""Session fixture for tests that need direct DB access."""
|
"""Session fixture for tests that need direct DB access."""
|
||||||
|
|||||||
@ -1,9 +1,11 @@
|
|||||||
"""Tests for asset enrichment (mime_type and hash population)."""
|
"""Tests for asset enrichment (mime_type and hash population)."""
|
||||||
|
import os
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference
|
from app.assets.database.models import Asset, AssetReference
|
||||||
|
from app.assets.services.file_utils import get_mtime_ns
|
||||||
from app.assets.scanner import (
|
from app.assets.scanner import (
|
||||||
ENRICHMENT_HASHED,
|
ENRICHMENT_HASHED,
|
||||||
ENRICHMENT_METADATA,
|
ENRICHMENT_METADATA,
|
||||||
@ -20,6 +22,13 @@ def _create_stub_asset(
|
|||||||
name: str | None = None,
|
name: str | None = None,
|
||||||
) -> tuple[Asset, AssetReference]:
|
) -> tuple[Asset, AssetReference]:
|
||||||
"""Create a stub asset with reference for testing enrichment."""
|
"""Create a stub asset with reference for testing enrichment."""
|
||||||
|
# Use the real file's mtime so the optimistic guard in enrich_asset passes
|
||||||
|
try:
|
||||||
|
stat_result = os.stat(file_path, follow_symlinks=True)
|
||||||
|
mtime_ns = get_mtime_ns(stat_result)
|
||||||
|
except OSError:
|
||||||
|
mtime_ns = 1234567890000000000
|
||||||
|
|
||||||
asset = Asset(
|
asset = Asset(
|
||||||
id=asset_id,
|
id=asset_id,
|
||||||
hash=None,
|
hash=None,
|
||||||
@ -35,7 +44,7 @@ def _create_stub_asset(
|
|||||||
name=name or f"test-asset-{asset_id}",
|
name=name or f"test-asset-{asset_id}",
|
||||||
owner_id="system",
|
owner_id="system",
|
||||||
file_path=file_path,
|
file_path=file_path,
|
||||||
mtime_ns=1234567890000000000,
|
mtime_ns=mtime_ns,
|
||||||
enrichment_level=ENRICHMENT_STUB,
|
enrichment_level=ENRICHMENT_STUB,
|
||||||
)
|
)
|
||||||
session.add(ref)
|
session.add(ref)
|
||||||
|
|||||||
@ -1,12 +1,18 @@
|
|||||||
"""Tests for ingest services."""
|
"""Tests for ingest services."""
|
||||||
|
from contextlib import contextmanager
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session as SASession, Session
|
||||||
|
|
||||||
from app.assets.database.models import Asset, AssetReference, Tag
|
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||||
from app.assets.database.queries import get_reference_tags
|
from app.assets.database.queries import get_reference_tags
|
||||||
from app.assets.services.ingest import _ingest_file_from_path, _register_existing_asset
|
from app.assets.services.ingest import (
|
||||||
|
_ingest_file_from_path,
|
||||||
|
_register_existing_asset,
|
||||||
|
ingest_existing_file,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class TestIngestFileFromPath:
|
class TestIngestFileFromPath:
|
||||||
@ -235,3 +241,42 @@ class TestRegisterExistingAsset:
|
|||||||
|
|
||||||
assert result.created is True
|
assert result.created is True
|
||||||
assert set(result.tags) == {"alpha", "beta"}
|
assert set(result.tags) == {"alpha", "beta"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestIngestExistingFileTagFK:
|
||||||
|
"""Regression: ingest_existing_file must seed Tag rows before inserting
|
||||||
|
AssetReferenceTag rows, otherwise FK enforcement raises IntegrityError."""
|
||||||
|
|
||||||
|
def test_creates_tag_rows_before_reference_tags(self, db_engine_fk, temp_dir: Path):
|
||||||
|
"""With PRAGMA foreign_keys=ON, tags must exist in the tags table
|
||||||
|
before they can be referenced in asset_reference_tags."""
|
||||||
|
|
||||||
|
@contextmanager
|
||||||
|
def _create_session():
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
yield sess
|
||||||
|
|
||||||
|
file_path = temp_dir / "output.png"
|
||||||
|
file_path.write_bytes(b"image data")
|
||||||
|
|
||||||
|
with patch("app.assets.services.ingest.create_session", _create_session), \
|
||||||
|
patch(
|
||||||
|
"app.assets.services.ingest.get_name_and_tags_from_asset_path",
|
||||||
|
return_value=("output.png", ["output"]),
|
||||||
|
):
|
||||||
|
result = ingest_existing_file(
|
||||||
|
abs_path=str(file_path),
|
||||||
|
extra_tags=["my-job"],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert result is True
|
||||||
|
|
||||||
|
with SASession(db_engine_fk) as sess:
|
||||||
|
tag_names = {t.name for t in sess.query(Tag).all()}
|
||||||
|
assert "output" in tag_names
|
||||||
|
assert "my-job" in tag_names
|
||||||
|
|
||||||
|
ref_tags = sess.query(AssetReferenceTag).all()
|
||||||
|
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||||
|
assert "output" in ref_tag_names
|
||||||
|
assert "my-job" in ref_tag_names
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
"""Unit tests for the _AssetSeeder background scanning class."""
|
"""Unit tests for the _AssetSeeder background scanning class."""
|
||||||
|
|
||||||
import threading
|
import threading
|
||||||
|
import time
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
@ -771,6 +772,188 @@ class TestSeederStopRestart:
|
|||||||
assert collected_roots[1] == ("input",)
|
assert collected_roots[1] == ("input",)
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichHandoff:
|
||||||
|
"""Test that the drain of _pending_enrich is atomic with start_enrich."""
|
||||||
|
|
||||||
|
def test_pending_enrich_runs_after_scan_completes(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""A queued enrich request runs automatically when a scan finishes."""
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
original_start = fresh_seeder.start
|
||||||
|
|
||||||
|
def tracking_start(*args, **kwargs):
|
||||||
|
phase = kwargs.get("phase")
|
||||||
|
roots = kwargs.get("roots", args[0] if args else None)
|
||||||
|
result = original_start(*args, **kwargs)
|
||||||
|
if phase == ScanPhase.ENRICH and result:
|
||||||
|
enrich_roots_seen.append(roots)
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start = tracking_start
|
||||||
|
|
||||||
|
# Start a fast scan, then enqueue an enrich while it's running
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued is False # queued, not started immediately
|
||||||
|
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for the original scan + the auto-started enrich scan
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
assert enrich_roots_seen == [("input",)]
|
||||||
|
|
||||||
|
def test_enqueue_enrich_during_drain_does_not_lose_work(
|
||||||
|
self, fresh_seeder: _AssetSeeder, mock_dependencies
|
||||||
|
):
|
||||||
|
"""enqueue_enrich called concurrently with drain cannot drop work.
|
||||||
|
|
||||||
|
Simulates the race: another thread calls enqueue_enrich right as the
|
||||||
|
scan thread is draining _pending_enrich. The enqueue must either be
|
||||||
|
picked up by the draining scan or successfully start its own scan.
|
||||||
|
"""
|
||||||
|
barrier = threading.Event()
|
||||||
|
reached = threading.Event()
|
||||||
|
enrich_started = threading.Event()
|
||||||
|
|
||||||
|
enrich_call_count = 0
|
||||||
|
|
||||||
|
def slow_collect(*args):
|
||||||
|
reached.set()
|
||||||
|
barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
# Track how many times start_enrich actually fires
|
||||||
|
real_start_enrich = fresh_seeder.start_enrich
|
||||||
|
enrich_roots_seen: list[tuple] = []
|
||||||
|
|
||||||
|
def tracking_start_enrich(**kwargs):
|
||||||
|
nonlocal enrich_call_count
|
||||||
|
enrich_call_count += 1
|
||||||
|
enrich_roots_seen.append(kwargs.get("roots"))
|
||||||
|
result = real_start_enrich(**kwargs)
|
||||||
|
if result:
|
||||||
|
enrich_started.set()
|
||||||
|
return result
|
||||||
|
|
||||||
|
fresh_seeder.start_enrich = tracking_start_enrich
|
||||||
|
|
||||||
|
with patch(
|
||||||
|
"app.assets.seeder.collect_paths_for_roots", side_effect=slow_collect
|
||||||
|
):
|
||||||
|
# Start a scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# Queue an enrich while scan is running
|
||||||
|
fresh_seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
# Let scan finish — drain will fire start_enrich atomically
|
||||||
|
barrier.set()
|
||||||
|
|
||||||
|
# Wait for drain to complete and the enrich scan to start
|
||||||
|
assert enrich_started.wait(timeout=5.0), "Enrich scan was never started from drain"
|
||||||
|
assert ("output",) in enrich_roots_seen
|
||||||
|
|
||||||
|
def test_concurrent_enqueue_during_drain_not_lost(
|
||||||
|
self, fresh_seeder: _AssetSeeder,
|
||||||
|
):
|
||||||
|
"""A second enqueue_enrich arriving while drain is in progress is not lost.
|
||||||
|
|
||||||
|
Because the drain now holds _lock through the start_enrich call,
|
||||||
|
a concurrent enqueue_enrich will block until start_enrich has
|
||||||
|
transitioned state to RUNNING, then the enqueue will queue its
|
||||||
|
payload as _pending_enrich for the *next* drain.
|
||||||
|
"""
|
||||||
|
scan_barrier = threading.Event()
|
||||||
|
scan_reached = threading.Event()
|
||||||
|
enrich_barrier = threading.Event()
|
||||||
|
enrich_reached = threading.Event()
|
||||||
|
|
||||||
|
collect_call = 0
|
||||||
|
|
||||||
|
def gated_collect(*args):
|
||||||
|
nonlocal collect_call
|
||||||
|
collect_call += 1
|
||||||
|
if collect_call == 1:
|
||||||
|
# First call: the initial fast scan
|
||||||
|
scan_reached.set()
|
||||||
|
scan_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
enrich_call = 0
|
||||||
|
|
||||||
|
def gated_get_unenriched(*args, **kwargs):
|
||||||
|
nonlocal enrich_call
|
||||||
|
enrich_call += 1
|
||||||
|
if enrich_call == 1:
|
||||||
|
# First enrich batch: signal and block
|
||||||
|
enrich_reached.set()
|
||||||
|
enrich_barrier.wait(timeout=5.0)
|
||||||
|
return []
|
||||||
|
|
||||||
|
with (
|
||||||
|
patch("app.assets.seeder.dependencies_available", return_value=True),
|
||||||
|
patch("app.assets.seeder.sync_root_safely", return_value=set()),
|
||||||
|
patch("app.assets.seeder.collect_paths_for_roots", side_effect=gated_collect),
|
||||||
|
patch("app.assets.seeder.build_asset_specs", return_value=([], set(), 0)),
|
||||||
|
patch("app.assets.seeder.insert_asset_specs", return_value=0),
|
||||||
|
patch("app.assets.seeder.get_unenriched_assets_for_roots", side_effect=gated_get_unenriched),
|
||||||
|
patch("app.assets.seeder.enrich_assets_batch", return_value=(0, 0)),
|
||||||
|
):
|
||||||
|
# 1. Start fast scan
|
||||||
|
fresh_seeder.start(roots=("models",), phase=ScanPhase.FAST)
|
||||||
|
assert scan_reached.wait(timeout=2.0)
|
||||||
|
|
||||||
|
# 2. Queue enrich while fast scan is running
|
||||||
|
queued = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("input",), compute_hashes=False
|
||||||
|
)
|
||||||
|
assert queued is False
|
||||||
|
|
||||||
|
# 3. Let the fast scan finish — drain will start the enrich scan
|
||||||
|
scan_barrier.set()
|
||||||
|
|
||||||
|
# 4. Wait until the drained enrich scan is running
|
||||||
|
assert enrich_reached.wait(timeout=5.0)
|
||||||
|
|
||||||
|
# 5. Now enqueue another enrich while the drained scan is running
|
||||||
|
queued2 = fresh_seeder.enqueue_enrich(
|
||||||
|
roots=("output",), compute_hashes=True
|
||||||
|
)
|
||||||
|
assert queued2 is False # should be queued, not started
|
||||||
|
|
||||||
|
# Verify _pending_enrich was set (the second enqueue was captured)
|
||||||
|
with fresh_seeder._lock:
|
||||||
|
assert fresh_seeder._pending_enrich is not None
|
||||||
|
assert "output" in fresh_seeder._pending_enrich["roots"]
|
||||||
|
|
||||||
|
# Let the enrich scan finish
|
||||||
|
enrich_barrier.set()
|
||||||
|
|
||||||
|
deadline = time.monotonic() + 5.0
|
||||||
|
while fresh_seeder.get_status().state != State.IDLE and time.monotonic() < deadline:
|
||||||
|
time.sleep(0.05)
|
||||||
|
|
||||||
|
|
||||||
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
def _make_row(ref_id: str, asset_id: str = "a1") -> UnenrichedReferenceRow:
|
||||||
return UnenrichedReferenceRow(
|
return UnenrichedReferenceRow(
|
||||||
reference_id=ref_id, asset_id=asset_id,
|
reference_id=ref_id, asset_id=asset_id,
|
||||||
|
|||||||
250
tests/test_asset_seeder.py
Normal file
250
tests/test_asset_seeder.py
Normal file
@ -0,0 +1,250 @@
|
|||||||
|
"""Tests for app.assets.seeder – enqueue_enrich and pending-queue behaviour."""
|
||||||
|
|
||||||
|
import threading
|
||||||
|
from unittest.mock import patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from app.assets.seeder import Progress, _AssetSeeder, State
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture()
|
||||||
|
def seeder():
|
||||||
|
"""Fresh seeder instance for each test."""
|
||||||
|
return _AssetSeeder()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# _reset_to_idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestResetToIdle:
|
||||||
|
def test_sets_idle_and_clears_progress(self, seeder):
|
||||||
|
"""_reset_to_idle should move state to IDLE and snapshot progress."""
|
||||||
|
progress = Progress(scanned=10, total=20, created=5, skipped=3)
|
||||||
|
seeder._state = State.RUNNING
|
||||||
|
seeder._progress = progress
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is progress
|
||||||
|
|
||||||
|
def test_noop_when_progress_already_none(self, seeder):
|
||||||
|
"""_reset_to_idle should handle None progress gracefully."""
|
||||||
|
seeder._state = State.CANCELLING
|
||||||
|
seeder._progress = None
|
||||||
|
|
||||||
|
with seeder._lock:
|
||||||
|
seeder._reset_to_idle()
|
||||||
|
|
||||||
|
assert seeder._state is State.IDLE
|
||||||
|
assert seeder._progress is None
|
||||||
|
assert seeder._last_progress is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – immediate start when idle
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichStartsImmediately:
|
||||||
|
def test_starts_when_idle(self, seeder):
|
||||||
|
"""enqueue_enrich should delegate to start_enrich and return True when idle."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock:
|
||||||
|
assert seeder.enqueue_enrich(roots=("output",), compute_hashes=True) is True
|
||||||
|
mock.assert_called_once_with(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
def test_no_pending_when_started_immediately(self, seeder):
|
||||||
|
"""No pending request should be stored when start_enrich succeeds."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True):
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – queuing when busy
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichQueuesWhenBusy:
|
||||||
|
def test_queues_when_busy(self, seeder):
|
||||||
|
"""enqueue_enrich should store a pending request when seeder is busy."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
result = seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert result is False
|
||||||
|
assert seeder._pending_enrich == {
|
||||||
|
"roots": ("models",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
def test_queues_preserves_compute_hashes_true(self, seeder):
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# enqueue_enrich – merging when a pending request already exists
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichMergesPending:
|
||||||
|
def _make_busy(self, seeder):
|
||||||
|
"""Patch start_enrich to always return False (seeder busy)."""
|
||||||
|
return patch.object(seeder, "start_enrich", return_value=False)
|
||||||
|
|
||||||
|
def test_merges_roots(self, seeder):
|
||||||
|
"""A second enqueue should merge roots with the existing pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",))
|
||||||
|
seeder.enqueue_enrich(roots=("output",))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "output"}
|
||||||
|
|
||||||
|
def test_merges_overlapping_roots(self, seeder):
|
||||||
|
"""Duplicate roots should be deduplicated."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models", "input"))
|
||||||
|
seeder.enqueue_enrich(roots=("input", "output"))
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
|
||||||
|
def test_compute_hashes_sticky_true(self, seeder):
|
||||||
|
"""Once compute_hashes is True it should stay True after merging."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=True)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_upgrades_to_true(self, seeder):
|
||||||
|
"""A later enqueue with compute_hashes=True should upgrade the pending request."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
def test_compute_hashes_stays_false(self, seeder):
|
||||||
|
"""If both enqueues have compute_hashes=False it stays False."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=False)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is False
|
||||||
|
|
||||||
|
def test_triple_merge(self, seeder):
|
||||||
|
"""Three successive enqueues should all merge correctly."""
|
||||||
|
with self._make_busy(seeder):
|
||||||
|
seeder.enqueue_enrich(roots=("models",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("input",), compute_hashes=False)
|
||||||
|
seeder.enqueue_enrich(roots=("output",), compute_hashes=True)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
|
assert seeder._pending_enrich["compute_hashes"] is True
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Pending enrich drains after scan completes
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestPendingEnrichDrain:
|
||||||
|
"""Verify that _run_scan drains _pending_enrich via start_enrich."""
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_enrich_starts_after_scan(self, *_mocks):
|
||||||
|
"""After a fast scan finishes, the pending enrich should be started."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": True,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_called_once_with(
|
||||||
|
roots=("output",),
|
||||||
|
compute_hashes=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_pending_cleared_even_when_start_fails(self, *_mocks):
|
||||||
|
"""_pending_enrich should be cleared even if start_enrich returns False."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
seeder._pending_enrich = {
|
||||||
|
"roots": ("output",),
|
||||||
|
"compute_hashes": False,
|
||||||
|
}
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
@patch("app.assets.seeder.dependencies_available", return_value=True)
|
||||||
|
@patch("app.assets.seeder.get_all_known_prefixes", return_value=[])
|
||||||
|
@patch("app.assets.seeder.sync_root_safely", return_value=set())
|
||||||
|
@patch("app.assets.seeder.collect_paths_for_roots", return_value=[])
|
||||||
|
@patch("app.assets.seeder.build_asset_specs", return_value=([], {}, 0))
|
||||||
|
def test_no_drain_when_no_pending(self, *_mocks):
|
||||||
|
"""start_enrich should not be called when there is no pending request."""
|
||||||
|
seeder = _AssetSeeder()
|
||||||
|
assert seeder._pending_enrich is None
|
||||||
|
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=True) as mock_start:
|
||||||
|
seeder.start_fast(roots=("models",))
|
||||||
|
seeder.wait(timeout=5)
|
||||||
|
|
||||||
|
mock_start.assert_not_called()
|
||||||
|
|
||||||
|
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
# Thread-safety of enqueue_enrich
|
||||||
|
# ---------------------------------------------------------------------------
|
||||||
|
|
||||||
|
|
||||||
|
class TestEnqueueEnrichThreadSafety:
|
||||||
|
def test_concurrent_enqueues(self, seeder):
|
||||||
|
"""Multiple threads enqueuing should not lose roots."""
|
||||||
|
with patch.object(seeder, "start_enrich", return_value=False):
|
||||||
|
barrier = threading.Barrier(3)
|
||||||
|
|
||||||
|
def enqueue(root):
|
||||||
|
barrier.wait()
|
||||||
|
seeder.enqueue_enrich(roots=(root,), compute_hashes=False)
|
||||||
|
|
||||||
|
threads = [
|
||||||
|
threading.Thread(target=enqueue, args=(r,))
|
||||||
|
for r in ("models", "input", "output")
|
||||||
|
]
|
||||||
|
for t in threads:
|
||||||
|
t.start()
|
||||||
|
for t in threads:
|
||||||
|
t.join(timeout=5)
|
||||||
|
|
||||||
|
merged = set(seeder._pending_enrich["roots"])
|
||||||
|
assert merged == {"models", "input", "output"}
|
||||||
Loading…
Reference in New Issue
Block a user