mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-07 02:07:32 +08:00
Populate mime_type for assets in scanner and API paths
- Add custom MIME type registrations for model files (.safetensors, .pt, .ckpt, .gguf, .yaml) - Pass mime_type through SeedAssetSpec to bulk_ingest - Re-register types before use since server.py mimetypes.init() resets them - Add tests for bulk ingest mime_type handling Amp-Thread-ID: https://ampcode.com/threads/T-019c3626-c6ad-7139-a570-62da4e656a1a Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
b378e69aed
commit
9016593586
@ -320,6 +320,9 @@ def build_asset_specs(
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", abs_p, e)
|
logging.warning("Failed to hash %s: %s", abs_p, e)
|
||||||
|
|
||||||
|
mime_type = metadata.content_type if metadata else None
|
||||||
|
if mime_type is None:
|
||||||
|
print(f"[build_asset_specs] no mime_type for {abs_p} (metadata={metadata is not None})")
|
||||||
specs.append(
|
specs.append(
|
||||||
{
|
{
|
||||||
"abs_path": abs_p,
|
"abs_path": abs_p,
|
||||||
@ -330,6 +333,7 @@ def build_asset_specs(
|
|||||||
"fname": rel_fname,
|
"fname": rel_fname,
|
||||||
"metadata": metadata,
|
"metadata": metadata,
|
||||||
"hash": asset_hash,
|
"hash": asset_hash,
|
||||||
|
"mime_type": mime_type,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
tag_pool.update(tags)
|
tag_pool.update(tags)
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
import uuid
|
import uuid
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -38,6 +39,7 @@ class SeedAssetSpec(TypedDict):
|
|||||||
fname: str
|
fname: str
|
||||||
metadata: ExtractedMetadata | None
|
metadata: ExtractedMetadata | None
|
||||||
hash: str | None
|
hash: str | None
|
||||||
|
mime_type: str | None
|
||||||
|
|
||||||
|
|
||||||
class AssetRow(TypedDict):
|
class AssetRow(TypedDict):
|
||||||
@ -162,12 +164,15 @@ def batch_insert_seed_assets(
|
|||||||
absolute_path_list.append(absolute_path)
|
absolute_path_list.append(absolute_path)
|
||||||
path_to_asset_id[absolute_path] = asset_id
|
path_to_asset_id[absolute_path] = asset_id
|
||||||
|
|
||||||
|
mime_type = spec.get("mime_type")
|
||||||
|
if mime_type is None:
|
||||||
|
logging.info("batch_insert_seed_assets: no mime_type for %s", absolute_path)
|
||||||
asset_rows.append(
|
asset_rows.append(
|
||||||
{
|
{
|
||||||
"id": asset_id,
|
"id": asset_id,
|
||||||
"hash": spec.get("hash"),
|
"hash": spec.get("hash"),
|
||||||
"size_bytes": spec["size_bytes"],
|
"size_bytes": spec["size_bytes"],
|
||||||
"mime_type": None,
|
"mime_type": mime_type,
|
||||||
"created_at": current_time,
|
"created_at": current_time,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|||||||
@ -20,6 +20,31 @@ SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
|
|||||||
# Maximum safetensors header size to read (8MB)
|
# Maximum safetensors header size to read (8MB)
|
||||||
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
|
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
|
||||||
|
|
||||||
|
def _register_custom_mime_types():
|
||||||
|
"""Register custom MIME types for model and config files.
|
||||||
|
|
||||||
|
Called before each use because mimetypes.init() in server.py resets the database.
|
||||||
|
Uses a quick check to avoid redundant registrations.
|
||||||
|
"""
|
||||||
|
# Quick check if already registered (avoids redundant add_type calls)
|
||||||
|
test_result, _ = mimetypes.guess_type("test.safetensors")
|
||||||
|
if test_result == "application/safetensors":
|
||||||
|
return
|
||||||
|
|
||||||
|
mimetypes.add_type("application/safetensors", ".safetensors")
|
||||||
|
mimetypes.add_type("application/safetensors", ".sft")
|
||||||
|
mimetypes.add_type("application/pytorch", ".pt")
|
||||||
|
mimetypes.add_type("application/pytorch", ".pth")
|
||||||
|
mimetypes.add_type("application/pickle", ".ckpt")
|
||||||
|
mimetypes.add_type("application/pickle", ".pkl")
|
||||||
|
mimetypes.add_type("application/gguf", ".gguf")
|
||||||
|
mimetypes.add_type("application/yaml", ".yaml")
|
||||||
|
mimetypes.add_type("application/yaml", ".yml")
|
||||||
|
|
||||||
|
|
||||||
|
# Register custom types at module load
|
||||||
|
_register_custom_mime_types()
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ExtractedMetadata:
|
class ExtractedMetadata:
|
||||||
@ -284,9 +309,12 @@ def extract_file_metadata(
|
|||||||
_, ext = os.path.splitext(abs_path)
|
_, ext = os.path.splitext(abs_path)
|
||||||
meta.format = ext.lstrip(".").lower() if ext else ""
|
meta.format = ext.lstrip(".").lower() if ext else ""
|
||||||
|
|
||||||
# MIME type guess
|
# MIME type guess (re-register in case mimetypes.init() was called elsewhere)
|
||||||
|
_register_custom_mime_types()
|
||||||
mime_type, _ = mimetypes.guess_type(abs_path)
|
mime_type, _ = mimetypes.guess_type(abs_path)
|
||||||
meta.content_type = mime_type
|
meta.content_type = mime_type
|
||||||
|
if mime_type is None:
|
||||||
|
print(f"[extract_file_metadata] No mime_type for {abs_path}")
|
||||||
|
|
||||||
# Size from stat
|
# Size from stat
|
||||||
if stat_result is None:
|
if stat_result is None:
|
||||||
|
|||||||
139
tests-unit/assets_test/services/test_bulk_ingest.py
Normal file
139
tests-unit/assets_test/services/test_bulk_ingest.py
Normal file
@ -0,0 +1,139 @@
|
|||||||
|
"""Tests for bulk ingest services."""
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.assets.database.models import Asset
|
||||||
|
from app.assets.services.bulk_ingest import SeedAssetSpec, batch_insert_seed_assets
|
||||||
|
|
||||||
|
|
||||||
|
class TestBatchInsertSeedAssets:
|
||||||
|
def test_populates_mime_type_for_model_files(self, session: Session, temp_dir: Path):
|
||||||
|
"""Verify mime_type is stored in the Asset table for model files."""
|
||||||
|
file_path = temp_dir / "model.safetensors"
|
||||||
|
file_path.write_bytes(b"fake safetensors content")
|
||||||
|
|
||||||
|
specs: list[SeedAssetSpec] = [
|
||||||
|
{
|
||||||
|
"abs_path": str(file_path),
|
||||||
|
"size_bytes": 24,
|
||||||
|
"mtime_ns": 1234567890000000000,
|
||||||
|
"info_name": "Test Model",
|
||||||
|
"tags": ["models"],
|
||||||
|
"fname": "model.safetensors",
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": "application/safetensors",
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||||
|
|
||||||
|
assert result.inserted_infos == 1
|
||||||
|
|
||||||
|
# Verify Asset has mime_type populated
|
||||||
|
assets = session.query(Asset).all()
|
||||||
|
assert len(assets) == 1
|
||||||
|
assert assets[0].mime_type == "application/safetensors"
|
||||||
|
|
||||||
|
def test_mime_type_none_when_not_provided(self, session: Session, temp_dir: Path):
|
||||||
|
"""Verify mime_type is None when not provided in spec."""
|
||||||
|
file_path = temp_dir / "unknown.bin"
|
||||||
|
file_path.write_bytes(b"binary data")
|
||||||
|
|
||||||
|
specs: list[SeedAssetSpec] = [
|
||||||
|
{
|
||||||
|
"abs_path": str(file_path),
|
||||||
|
"size_bytes": 11,
|
||||||
|
"mtime_ns": 1234567890000000000,
|
||||||
|
"info_name": "Unknown File",
|
||||||
|
"tags": [],
|
||||||
|
"fname": "unknown.bin",
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": None,
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||||
|
|
||||||
|
assert result.inserted_infos == 1
|
||||||
|
|
||||||
|
assets = session.query(Asset).all()
|
||||||
|
assert len(assets) == 1
|
||||||
|
assert assets[0].mime_type is None
|
||||||
|
|
||||||
|
def test_various_model_mime_types(self, session: Session, temp_dir: Path):
|
||||||
|
"""Verify various model file types get correct mime_type."""
|
||||||
|
test_cases = [
|
||||||
|
("model.safetensors", "application/safetensors"),
|
||||||
|
("model.pt", "application/pytorch"),
|
||||||
|
("model.ckpt", "application/pickle"),
|
||||||
|
("model.gguf", "application/gguf"),
|
||||||
|
]
|
||||||
|
|
||||||
|
specs: list[SeedAssetSpec] = []
|
||||||
|
for filename, mime_type in test_cases:
|
||||||
|
file_path = temp_dir / filename
|
||||||
|
file_path.write_bytes(b"content")
|
||||||
|
specs.append(
|
||||||
|
{
|
||||||
|
"abs_path": str(file_path),
|
||||||
|
"size_bytes": 7,
|
||||||
|
"mtime_ns": 1234567890000000000,
|
||||||
|
"info_name": filename,
|
||||||
|
"tags": [],
|
||||||
|
"fname": filename,
|
||||||
|
"metadata": None,
|
||||||
|
"hash": None,
|
||||||
|
"mime_type": mime_type,
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
result = batch_insert_seed_assets(session, specs=specs, owner_id="")
|
||||||
|
|
||||||
|
assert result.inserted_infos == len(test_cases)
|
||||||
|
|
||||||
|
for filename, expected_mime in test_cases:
|
||||||
|
from app.assets.database.models import AssetInfo
|
||||||
|
info = session.query(AssetInfo).filter_by(name=filename).first()
|
||||||
|
assert info is not None
|
||||||
|
asset = session.query(Asset).filter_by(id=info.asset_id).first()
|
||||||
|
assert asset.mime_type == expected_mime, f"Expected {expected_mime} for {filename}, got {asset.mime_type}"
|
||||||
|
|
||||||
|
|
||||||
|
class TestMetadataExtraction:
|
||||||
|
def test_extracts_mime_type_for_model_files(self, temp_dir: Path):
|
||||||
|
"""Verify metadata extraction returns correct mime_type for model files."""
|
||||||
|
from app.assets.services.metadata_extract import extract_file_metadata
|
||||||
|
|
||||||
|
file_path = temp_dir / "model.safetensors"
|
||||||
|
file_path.write_bytes(b"fake safetensors content")
|
||||||
|
|
||||||
|
meta = extract_file_metadata(str(file_path))
|
||||||
|
|
||||||
|
assert meta.content_type == "application/safetensors"
|
||||||
|
|
||||||
|
def test_mime_type_for_various_model_formats(self, temp_dir: Path):
|
||||||
|
"""Verify various model file types get correct mime_type from metadata."""
|
||||||
|
from app.assets.services.metadata_extract import extract_file_metadata
|
||||||
|
|
||||||
|
test_cases = [
|
||||||
|
("model.safetensors", "application/safetensors"),
|
||||||
|
("model.sft", "application/safetensors"),
|
||||||
|
("model.pt", "application/pytorch"),
|
||||||
|
("model.pth", "application/pytorch"),
|
||||||
|
("model.ckpt", "application/pickle"),
|
||||||
|
("model.pkl", "application/pickle"),
|
||||||
|
("model.gguf", "application/gguf"),
|
||||||
|
]
|
||||||
|
|
||||||
|
for filename, expected_mime in test_cases:
|
||||||
|
file_path = temp_dir / filename
|
||||||
|
file_path.write_bytes(b"content")
|
||||||
|
|
||||||
|
meta = extract_file_metadata(str(file_path))
|
||||||
|
|
||||||
|
assert meta.content_type == expected_mime, f"Expected {expected_mime} for {filename}, got {meta.content_type}"
|
||||||
Loading…
Reference in New Issue
Block a user