fix(assets): merge duplicate scan specs

Co-authored-by: Amp <amp@ampcode.com>
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5
This commit is contained in:
Simon Pinfold 2026-07-02 07:11:36 +12:00
parent 04198cd192
commit ccc9387298
2 changed files with 156 additions and 0 deletions

View File

@ -134,6 +134,14 @@ def batch_insert_seed_assets(
for spec in specs: for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"]) absolute_path = os.path.abspath(spec["abs_path"])
existing_asset_id = path_to_asset_id.get(absolute_path)
if existing_asset_id is not None:
existing_tags = asset_id_to_ref_data[existing_asset_id]["tags"]
asset_id_to_ref_data[existing_asset_id]["tags"] = list(
dict.fromkeys([*existing_tags, *spec["tags"]])
)
continue
asset_id = str(uuid.uuid4()) asset_id = str(uuid.uuid4())
reference_id = str(uuid.uuid4()) reference_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path) absolute_path_list.append(absolute_path)

View File

@ -1,10 +1,13 @@
"""Tests for bulk ingest services.""" """Tests for bulk ingest services."""
from pathlib import Path from pathlib import Path
from unittest.mock import patch
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import get_reference_tags
from app.assets.scanner import build_asset_specs
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
@ -101,6 +104,151 @@ class TestBatchInsertSeedAssets:
asset = session.query(Asset).filter_by(id=ref.asset_id).first() asset = session.query(Asset).filter_by(id=ref.asset_id).first()
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}" assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
def test_duplicate_paths_merge_tags_before_insert(
    self, session: Session, temp_dir: Path
):
    """Two specs naming one path yield a single reference with merged tags.

    Both specs point at the same file and differ only in their
    ``model_type:*`` tag. The insert must report one inserted reference
    and one won path, and the stored reference must carry the union of
    both tag lists.
    """
    model_file = temp_dir / "shared.safetensors"
    model_file.write_bytes(b"shared model")

    def spec_with_tags(tags: list[str]) -> SeedAssetSpec:
        # Everything except the tag list is identical between the two specs.
        return {
            "abs_path": str(model_file),
            "size_bytes": 12,
            "mtime_ns": 1234567890000000000,
            "info_name": "Shared Model",
            "tags": tags,
            "fname": "shared.safetensors",
            "metadata": None,
            "hash": None,
            "mime_type": "application/safetensors",
        }

    specs: list[SeedAssetSpec] = [
        spec_with_tags(["models", "model_type:checkpoints"]),
        spec_with_tags(["models", "model_type:diffusion_models"]),
    ]

    outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")
    assert outcome.inserted_refs == 1
    assert outcome.won_paths == 1

    stored_refs = session.query(AssetReference).all()
    assert len(stored_refs) == 1

    # The surviving reference must hold the tags of *both* input specs.
    merged_tags = set(get_reference_tags(session, reference_id=stored_refs[0].id))
    assert merged_tags == {
        "models",
        "model_type:checkpoints",
        "model_type:diffusion_models",
    }
def test_duplicate_paths_are_merged_after_abspath_normalization(
    self, session: Session, temp_dir: Path, monkeypatch
):
    """A relative and an absolute spelling of one file merge into one reference.

    The working directory is switched to ``temp_dir`` so that the bare
    file name and the full path denote the same file. The insert must
    store a single reference under the full path, carrying the tags of
    both specs.
    """
    model_file = temp_dir / "same-file.safetensors"
    model_file.write_bytes(b"shared model")

    # NOTE(review): the final file_path assertion presumably relies on
    # temp_dir being a canonical (symlink-free) path, so that the relative
    # spelling normalizes to str(model_file) - confirm against the fixture.
    monkeypatch.chdir(temp_dir)
    bare_name = model_file.name

    def spec_for(path: str, tags: list[str]) -> SeedAssetSpec:
        # Only the path spelling and the tag list vary between the specs.
        return {
            "abs_path": path,
            "size_bytes": 12,
            "mtime_ns": 1234567890000000000,
            "info_name": "Shared Model",
            "tags": tags,
            "fname": "same-file.safetensors",
            "metadata": None,
            "hash": None,
            "mime_type": "application/safetensors",
        }

    specs: list[SeedAssetSpec] = [
        spec_for(bare_name, ["models", "model_type:checkpoints"]),
        spec_for(str(model_file), ["models", "model_type:diffusion_models"]),
    ]

    outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")
    assert outcome.inserted_refs == 1
    assert outcome.won_paths == 1

    stored_refs = session.query(AssetReference).all()
    assert len(stored_refs) == 1

    only_ref = stored_refs[0]
    assert only_ref.file_path == str(model_file)

    merged_tags = set(get_reference_tags(session, reference_id=only_ref.id))
    assert merged_tags == {
        "models",
        "model_type:checkpoints",
        "model_type:diffusion_models",
    }
def test_scanner_duplicate_shared_model_paths_keep_all_model_type_tags(
self, session: Session, temp_dir: Path
):
"""Shared extra model roots make scanner collection emit duplicate paths.

One directory is registered as the root for both ``checkpoints`` and
``diffusion_models``; the same file path is then handed to
``build_asset_specs`` twice and the resulting specs are inserted with
``batch_insert_seed_assets``. The test expects two specs from the
scanner, but only one stored reference carrying both ``model_type:*``
tags.
"""
# Lay out one shared model root plus separate input/output/temp directories.
shared_root = temp_dir / "shared"
input_dir = temp_dir / "input"
output_dir = temp_dir / "output"
temp_root = temp_dir / "temp"
for directory in (shared_root, input_dir, output_dir, temp_root):
directory.mkdir()
file_path = shared_root / "dual_use_model.safetensors"
file_path.write_bytes(b"shared model")
with (
# Replace the folder_paths object used by path_utils with a mock.
patch("app.assets.services.path_utils.folder_paths") as mock_fp,
# Report the same root directory for two different model categories.
patch(
"app.assets.services.path_utils.get_comfy_models_folders",
return_value=[
("checkpoints", [str(shared_root)]),
("diffusion_models", [str(shared_root)]),
],
),
):
# Point the mocked input/output/temp lookups at the directories created above.
mock_fp.get_input_directory.return_value = str(input_dir)
mock_fp.get_output_directory.return_value = str(output_dir)
mock_fp.get_temp_directory.return_value = str(temp_root)
# The same path is passed twice on purpose to reproduce the duplicate.
specs, tag_pool, skipped = build_asset_specs(
paths=[str(file_path), str(file_path)],
existing_paths=set(),
enable_metadata_extraction=False,
compute_hashes=False,
)
# The scanner itself is expected to keep both entries (no de-duplication here).
assert skipped == 0
assert len(specs) == 2
assert tag_pool == {
"models",
"model_type:checkpoints",
"model_type:diffusion_models",
}
# De-duplication is expected at insert time: one reference, one won path.
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
assert result.inserted_refs == 1
assert result.won_paths == 1
refs = session.query(AssetReference).all()
assert len(refs) == 1
# The single stored reference must carry the tags of both categories.
assert set(get_reference_tags(session, reference_id=refs[0].id)) == {
"models",
"model_type:checkpoints",
"model_type:diffusion_models",
}
class TestMetadataExtraction: class TestMetadataExtraction:
def test_extracts_mime_type_for_model_files(self, temp_dir: Path): def test_extracts_mime_type_for_model_files(self, temp_dir: Path):