mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-07-04 05:31:03 +08:00
Cover the behaviour that has no production change but is easy to regress: the extra-path asymmetry (loadable but no storage namespace), null loader_path persistence for orphan files, and the response reading the stored column with a compute fallback for un-backfilled rows.
320 lines
12 KiB
Python
320 lines
12 KiB
Python
"""Tests for bulk ingest services."""
|
|
|
|
import os
|
|
from pathlib import Path
|
|
from unittest.mock import patch
|
|
|
|
from sqlalchemy.orm import Session
|
|
|
|
from app.assets.database.models import Asset, AssetReference
|
|
from app.assets.database.queries import get_reference_tags
|
|
from app.assets.scanner import build_asset_specs
|
|
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
|
|
|
|
|
|
class TestBatchInsertSeedAssets:
    """Tests for batch_insert_seed_assets.

    Covers mime_type persistence on Asset rows, merging of specs that point
    at the same file, and how loader_path is stored on AssetReference rows.
    """

    @staticmethod
    def _seed_spec(
        abs_path: str,
        *,
        size_bytes: int,
        info_name: str,
        tags: "list[str]",
        fname: "str | None",
        mime_type: "str | None",
    ) -> "SeedAssetSpec":
        """Build one seed spec; mtime_ns, metadata and hash are fixed for every test."""
        spec: SeedAssetSpec = {
            "abs_path": abs_path,
            "size_bytes": size_bytes,
            "mtime_ns": 1234567890000000000,
            "info_name": info_name,
            "tags": tags,
            "fname": fname,
            "metadata": None,
            "hash": None,
            "mime_type": mime_type,
        }
        return spec

    def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path):
        """A mime_type supplied in the spec ends up on the stored Asset row."""
        model_file = temp_dir / "model.safetensors"
        model_file.write_bytes(b"fake safetensors content")

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(model_file),
                size_bytes=24,
                info_name="Test Model",
                tags=["models"],
                fname="model.safetensors",
                mime_type="application/safetensors",
            )
        ]

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == 1

        # Exactly one Asset row exists and it carries the mime_type from the spec.
        stored_assets = session.query(Asset).all()
        assert len(stored_assets) == 1
        assert stored_assets[0].mime_type == "application/safetensors"

    def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path):
        """A spec whose mime_type is None yields an Asset with mime_type None."""
        blob_file = temp_dir / "unknown.bin"
        blob_file.write_bytes(b"binary data")

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(blob_file),
                size_bytes=11,
                info_name="Unknown File",
                tags=[],
                fname="unknown.bin",
                mime_type=None,
            )
        ]

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == 1

        stored_assets = session.query(Asset).all()
        assert len(stored_assets) == 1
        assert stored_assets[0].mime_type is None

    def test_various_model_mime_types(self, session: Session, temp_dir: Path):
        """Each model file in a single batch keeps the mime_type given in its own spec."""
        test_cases = [
            ("model.safetensors", "application/safetensors"),
            ("model.pt", "application/pytorch"),
            ("model.ckpt", "application/pickle"),
            ("model.gguf", "application/gguf"),
        ]

        # Create the files first, then describe them.
        for filename, _ in test_cases:
            (temp_dir / filename).write_bytes(b"content")

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(temp_dir / filename),
                size_bytes=7,
                info_name=filename,
                tags=[],
                fname=filename,
                mime_type=mime_type,
            )
            for filename, mime_type in test_cases
        ]

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == len(test_cases)

        # Look each reference up by name and follow it to its Asset row.
        for filename, expected_mime in test_cases:
            ref = session.query(AssetReference).filter_by(name=filename).first()
            assert ref is not None
            asset = session.query(Asset).filter_by(id=ref.asset_id).first()
            assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"

    def test_duplicate_paths_merge_tags_before_insert(
        self, session: Session, temp_dir: Path
    ):
        """Two specs for the same path produce one reference carrying the union of their tags."""
        shared_file = temp_dir / "shared.safetensors"
        shared_file.write_bytes(b"shared model")

        # Same file registered twice, once per model type.
        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(shared_file),
                size_bytes=12,
                info_name="Shared Model",
                tags=["models", type_tag],
                fname="shared.safetensors",
                mime_type="application/safetensors",
            )
            for type_tag in ("model_type:checkpoints", "model_type:diffusion_models")
        ]

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == 1
        assert outcome.won_paths == 1

        stored_refs = session.query(AssetReference).all()
        assert len(stored_refs) == 1

        merged_tags = set(get_reference_tags(session, reference_id=stored_refs[0].id))
        assert merged_tags == {
            "models",
            "model_type:checkpoints",
            "model_type:diffusion_models",
        }

    def test_duplicate_paths_are_merged_after_abspath_normalization(
        self, session: Session, temp_dir: Path, monkeypatch
    ):
        """A relative and an absolute spelling of one file collapse into a single reference."""
        target_file = temp_dir / "same-file.safetensors"
        target_file.write_bytes(b"shared model")

        # Work from inside temp_dir so the bare file name resolves to target_file.
        monkeypatch.chdir(temp_dir)
        relative_path = target_file.name
        absolute_path = os.path.abspath(relative_path)

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                spelling,
                size_bytes=12,
                info_name="Shared Model",
                tags=["models", type_tag],
                fname="same-file.safetensors",
                mime_type="application/safetensors",
            )
            for spelling, type_tag in (
                (relative_path, "model_type:checkpoints"),
                (absolute_path, "model_type:diffusion_models"),
            )
        ]

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == 1
        assert outcome.won_paths == 1

        stored_refs = session.query(AssetReference).all()
        assert len(stored_refs) == 1

        only_ref = stored_refs[0]
        assert only_ref.file_path == absolute_path
        # loader_path is persisted from the spec's fname (compute_loader_path).
        assert only_ref.loader_path == "same-file.safetensors"
        assert set(get_reference_tags(session, reference_id=only_ref.id)) == {
            "models",
            "model_type:checkpoints",
            "model_type:diffusion_models",
        }

    def test_scanner_duplicate_shared_model_paths_keep_all_model_type_tags(
        self, session: Session, temp_dir: Path
    ):
        """One root registered under two model folders: the scanner emits two specs, the insert keeps both type tags."""
        shared_root = temp_dir / "shared"
        input_dir = temp_dir / "input"
        output_dir = temp_dir / "output"
        temp_root = temp_dir / "temp"
        for directory in (shared_root, input_dir, output_dir, temp_root):
            directory.mkdir()

        model_file = shared_root / "dual_use_model.safetensors"
        model_file.write_bytes(b"shared model")

        # Both model folder types point at the same directory.
        model_folders = [
            ("checkpoints", [str(shared_root)], {".safetensors"}),
            ("diffusion_models", [str(shared_root)], {".safetensors"}),
        ]
        expected_tags = {
            "models",
            "model_type:checkpoints",
            "model_type:diffusion_models",
        }

        # NOTE(review): the patches are scoped to spec collection only; widen the
        # block if the insert is also meant to run under the patched roots.
        with (
            patch("app.assets.services.path_utils.folder_paths") as mock_fp,
            patch(
                "app.assets.services.path_utils.get_comfy_models_folders",
                return_value=model_folders,
            ),
        ):
            mock_fp.get_input_directory.return_value = str(input_dir)
            mock_fp.get_output_directory.return_value = str(output_dir)
            mock_fp.get_temp_directory.return_value = str(temp_root)

            # The same path is passed twice on purpose.
            specs, tag_pool, skipped = build_asset_specs(
                paths=[str(model_file), str(model_file)],
                existing_paths=set(),
                enable_metadata_extraction=False,
                compute_hashes=False,
            )

        assert skipped == 0
        assert len(specs) == 2
        assert tag_pool == expected_tags

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == 1
        assert outcome.won_paths == 1

        stored_refs = session.query(AssetReference).all()
        assert len(stored_refs) == 1
        assert set(get_reference_tags(session, reference_id=stored_refs[0].id)) == expected_tags

    def test_loader_path_persisted_as_null_when_fname_is_none(
        self, session: Session, temp_dir: Path
    ):
        """A spec with fname=None (no in-root loader path, e.g. an orphan file)
        stores loader_path as NULL rather than a synthesized value."""
        orphan_file = temp_dir / "orphan.bin"
        orphan_file.write_bytes(b"x")

        specs: list[SeedAssetSpec] = [
            self._seed_spec(
                str(orphan_file),
                size_bytes=1,
                info_name="orphan.bin",
                tags=[],
                fname=None,
                mime_type=None,
            )
        ]

        outcome = batch_insert_seed_assets(session, specs=specs, owner_id="")

        assert outcome.inserted_refs == 1

        stored_refs = session.query(AssetReference).all()
        assert len(stored_refs) == 1
        assert stored_refs[0].file_path == str(orphan_file)
        assert stored_refs[0].loader_path is None
|
|
|
|
|
|
class TestMetadataExtraction:
    """Tests for the content type that extract_file_metadata reports for model files."""

    def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
        """A .safetensors file is reported with the application/safetensors content type."""
        from app.assets.services.metadata_extract import extract_file_metadata

        model_file = temp_dir / "model.safetensors"
        model_file.write_bytes(b"fake safetensors content")

        extracted = extract_file_metadata(str(model_file))

        assert extracted.content_type == "application/safetensors"

    def test_mime_type_for_various_model_formats(self, temp_dir: Path):
        """Each listed model file name is reported with its expected content type."""
        from app.assets.services.metadata_extract import extract_file_metadata

        # File name -> content type the extractor is expected to report.
        expected_by_filename = {
            "model.safetensors": "application/safetensors",
            "model.sft": "application/safetensors",
            "model.pt": "application/pytorch",
            "model.pth": "application/pytorch",
            "model.ckpt": "application/pickle",
            "model.pkl": "application/pickle",
            "model.gguf": "application/gguf",
        }

        for filename, expected_mime in expected_by_filename.items():
            candidate = temp_dir / filename
            candidate.write_bytes(b"content")

            meta = extract_file_metadata(str(candidate))

            assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}"
|