ComfyUI/tests-unit/assets_test/services/test_bulk_ingest.py
Simon Pinfold ca5adea2e3
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
test(assets): make duplicate path normalization portable
Amp-Thread-ID: https://ampcode.com/threads/T-019ecf39-2e6f-747d-ae80-addba6b8e4f5
Co-authored-by: Amp <amp@ampcode.com>
2026-07-02 08:24:30 +12:00

288 lines
11 KiB
Python

"""Tests for bulk ingest services."""
import os
from pathlib import Path
from unittest.mock import patch
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import get_reference_tags
from app.assets.scanner import build_asset_specs
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
class TestBatchInsertSeedAssets:
    """Tests for ``batch_insert_seed_assets`` driven through the ``session`` fixture.

    The first group checks that a spec's ``mime_type`` ends up on the stored
    ``Asset`` row. The second group checks that specs naming the same file
    collapse into a single reference carrying the union of their tags.
    """

    # Fixed modification time shared by every hand-built spec.
    _MTIME_NS = 1234567890000000000

    # Tags expected once two specs for the same file, tagged for two different
    # model types, have been merged.
    _MERGED_TAGS = frozenset(
        {"models", "model_type:checkpoints", "model_type:diffusion_models"}
    )

    @classmethod
    def _seed_spec(
        cls,
        abs_path: str,
        *,
        size_bytes: int,
        info_name: str,
        tags: list[str],
        mime_type: str | None,
    ) -> SeedAssetSpec:
        """Build a seed spec with no metadata and no hash.

        Args:
            abs_path: Path stored under ``"abs_path"``; passed through as given.
            size_bytes: Value stored under ``"size_bytes"``.
            info_name: Value stored under ``"info_name"``.
            tags: Tag list stored under ``"tags"`` (not copied).
            mime_type: Value stored under ``"mime_type"``; may be ``None``.

        Returns:
            A spec dict whose ``"fname"`` is the final component of
            ``abs_path`` and whose ``"mtime_ns"`` is ``_MTIME_NS``.
        """
        return {
            "abs_path": abs_path,
            "size_bytes": size_bytes,
            "mtime_ns": cls._MTIME_NS,
            "info_name": info_name,
            "tags": tags,
            "fname": os.path.basename(abs_path),
            "metadata": None,
            "hash": None,
            "mime_type": mime_type,
        }

    def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path):
        """Verify mime_type is stored in the Asset table for model files."""
        file_path = temp_dir / "model.safetensors"
        file_path.write_bytes(b"fake safetensors content")

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(file_path),
                size_bytes=24,
                info_name="Test Model",
                tags=["models"],
                mime_type="application/safetensors",
            )
        ]

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")
        assert result.inserted_refs == 1

        # Verify Asset has mime_type populated
        assets = session.query(Asset).all()
        assert len(assets) == 1
        assert assets[0].mime_type == "application/safetensors"

    def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path):
        """Verify mime_type is None when not provided in spec."""
        file_path = temp_dir / "unknown.bin"
        file_path.write_bytes(b"binary data")

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(file_path),
                size_bytes=11,
                info_name="Unknown File",
                tags=[],
                mime_type=None,
            )
        ]

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")
        assert result.inserted_refs == 1

        assets = session.query(Asset).all()
        assert len(assets) == 1
        assert assets[0].mime_type is None

    def test_various_model_mime_types(self, session: Session, temp_dir: Path):
        """Verify various model file types get correct mime_type."""
        test_cases = [
            ("model.safetensors", "application/safetensors"),
            ("model.pt", "application/pytorch"),
            ("model.ckpt", "application/pickle"),
            ("model.gguf", "application/gguf"),
        ]

        specs: list[SeedAssetSpec] = []
        for filename, mime_type in test_cases:
            file_path = temp_dir / filename
            file_path.write_bytes(b"content")
            specs.append(
                self._seed_spec(
                    str(file_path),
                    size_bytes=7,
                    info_name=filename,
                    tags=[],
                    mime_type=mime_type,
                )
            )

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")
        assert result.inserted_refs == len(test_cases)

        for filename, expected_mime in test_cases:
            ref = session.query(AssetReference).filter_by(name=filename).first()
            assert ref is not None
            asset = session.query(Asset).filter_by(id=ref.asset_id).first()
            # .first() yields None when no row matches; fail with a clear
            # assertion instead of an AttributeError on the next line.
            assert asset is not None, f"No Asset row found for {filename}"
            assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"

    def test_duplicate_paths_merge_tags_before_insert(
        self, session: Session, temp_dir: Path
    ):
        """Overlapping model-folder registrations can emit the same path twice."""
        file_path = temp_dir / "shared.safetensors"
        file_path.write_bytes(b"shared model")

        # Two specs for the same path that differ only in their model-type tag.
        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(file_path),
                size_bytes=12,
                info_name="Shared Model",
                tags=["models", "model_type:checkpoints"],
                mime_type="application/safetensors",
            ),
            self._seed_spec(
                str(file_path),
                size_bytes=12,
                info_name="Shared Model",
                tags=["models", "model_type:diffusion_models"],
                mime_type="application/safetensors",
            ),
        ]

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")
        assert result.inserted_refs == 1
        assert result.won_paths == 1

        refs = session.query(AssetReference).all()
        assert len(refs) == 1
        assert set(get_reference_tags(session, reference_id=refs[0].id)) == self._MERGED_TAGS

    def test_duplicate_paths_are_merged_after_abspath_normalization(
        self, session: Session, temp_dir: Path, monkeypatch
    ):
        """The scanner may emit equivalent paths with different spelling."""
        file_path = temp_dir / "same-file.safetensors"
        file_path.write_bytes(b"shared model")

        # Make the bare file name resolve to the file created above, then
        # derive the absolute spelling with the same os.path.abspath call the
        # assertion below compares against.
        monkeypatch.chdir(temp_dir)
        relative_path = file_path.name
        absolute_path = os.path.abspath(relative_path)

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                relative_path,
                size_bytes=12,
                info_name="Shared Model",
                tags=["models", "model_type:checkpoints"],
                mime_type="application/safetensors",
            ),
            self._seed_spec(
                absolute_path,
                size_bytes=12,
                info_name="Shared Model",
                tags=["models", "model_type:diffusion_models"],
                mime_type="application/safetensors",
            ),
        ]

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")
        assert result.inserted_refs == 1
        assert result.won_paths == 1

        refs = session.query(AssetReference).all()
        assert len(refs) == 1
        assert refs[0].file_path == absolute_path
        assert set(get_reference_tags(session, reference_id=refs[0].id)) == self._MERGED_TAGS

    def test_scanner_duplicate_shared_model_paths_keep_all_model_type_tags(
        self, session: Session, temp_dir: Path
    ):
        """Shared extra model roots make scanner collection emit duplicate paths."""
        shared_root = temp_dir / "shared"
        input_dir = temp_dir / "input"
        output_dir = temp_dir / "output"
        temp_root = temp_dir / "temp"
        for directory in (shared_root, input_dir, output_dir, temp_root):
            directory.mkdir()

        file_path = shared_root / "dual_use_model.safetensors"
        file_path.write_bytes(b"shared model")

        # Register the same root under two model categories so the one file
        # is reported for both "checkpoints" and "diffusion_models".
        with (
            patch("app.assets.services.path_utils.folder_paths") as mock_fp,
            patch(
                "app.assets.services.path_utils.get_comfy_models_folders",
                return_value=[
                    ("checkpoints", [str(shared_root)]),
                    ("diffusion_models", [str(shared_root)]),
                ],
            ),
        ):
            mock_fp.get_input_directory.return_value = str(input_dir)
            mock_fp.get_output_directory.return_value = str(output_dir)
            mock_fp.get_temp_directory.return_value = str(temp_root)

            specs, tag_pool, skipped = build_asset_specs(
                paths=[str(file_path), str(file_path)],
                existing_paths=set(),
                enable_metadata_extraction=False,
                compute_hashes=False,
            )

        assert skipped == 0
        assert len(specs) == 2
        assert tag_pool == self._MERGED_TAGS

        result = batch_insert_seed_assets(session, specs=specs, owner_id="")
        assert result.inserted_refs == 1
        assert result.won_paths == 1

        refs = session.query(AssetReference).all()
        assert len(refs) == 1
        assert set(get_reference_tags(session, reference_id=refs[0].id)) == self._MERGED_TAGS
class TestMetadataExtraction:
    """Check the ``content_type`` reported by ``extract_file_metadata``."""

    def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
        """A ``model.safetensors`` file is reported as ``application/safetensors``."""
        from app.assets.services.metadata_extract import extract_file_metadata

        model_file = temp_dir / "model.safetensors"
        model_file.write_bytes(b"fake safetensors content")

        reported_type = extract_file_metadata(str(model_file)).content_type
        assert reported_type == "application/safetensors"

    def test_mime_type_for_various_model_formats(self, temp_dir: Path):
        """Each known model file extension maps to its expected content type."""
        from app.assets.services.metadata_extract import extract_file_metadata

        # Extensions grouped by the content type they should produce; the
        # resulting iteration order is safetensors, sft, pt, pth, ckpt, pkl, gguf.
        suffixes_by_mime = {
            "application/safetensors": ("safetensors", "sft"),
            "application/pytorch": ("pt", "pth"),
            "application/pickle": ("ckpt", "pkl"),
            "application/gguf": ("gguf",),
        }

        for expected_mime, suffixes in suffixes_by_mime.items():
            for suffix in suffixes:
                filename = f"model.{suffix}"
                target = temp_dir / filename
                target.write_bytes(b"content")

                meta = extract_file_metadata(str(target))
                assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}"