refactor: add explicit types to asset service functions

- Add typed result dataclasses: IngestResult, AddTagsResult,
  RemoveTagsResult, SetTagsResult, TagUsage
- Add UserMetadata type alias for user_metadata parameters
- Annotate the `session` parameter of helper functions as `Session`
- Use TypedDicts at query layer to avoid circular imports
- Update manager.py and tests to use attribute access

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
Luke Mino-Altherr 2026-02-03 20:32:14 -08:00
parent 37ecc5b663
commit 9290e26e9f
10 changed files with 184 additions and 125 deletions

View File

@ -37,6 +37,9 @@ from app.assets.database.queries.cache_state import (
upsert_cache_state,
)
from app.assets.database.queries.tags import (
AddTagsDict,
RemoveTagsDict,
SetTagsDict,
add_missing_tag_for_asset_id,
add_tags_to_asset_info,
bulk_insert_tags_and_meta,
@ -90,4 +93,7 @@ __all__ = [
"remove_missing_tag_for_asset_id",
"list_tags_with_usage",
"bulk_insert_tags_and_meta",
"AddTagsDict",
"RemoveTagsDict",
"SetTagsDict",
]

View File

@ -1,4 +1,4 @@
from typing import Iterable, Sequence
from typing import Iterable, Sequence, TypedDict
import sqlalchemy as sa
from sqlalchemy import delete, func, select
@ -9,6 +9,24 @@ from sqlalchemy.orm import Session
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
class AddTagsDict(TypedDict):
    """Dictionary shape returned by ``add_tags_to_asset_info``."""

    # Requested tags that this call newly attached.
    added: list[str]
    # Requested tags that were already attached before the call.
    already_present: list[str]
    # Tags on the asset info once the call has finished
    # (presumably the complete tag list -- confirm against the query body).
    total_tags: list[str]
class RemoveTagsDict(TypedDict):
    """Dictionary shape returned by ``remove_tags_from_asset_info``."""

    # Requested tags that this call detached.
    removed: list[str]
    # Requested tags that were not attached in the first place.
    not_present: list[str]
    # Tags still on the asset info once the call has finished
    # (presumably the complete remaining list -- confirm against the query body).
    total_tags: list[str]
class SetTagsDict(TypedDict):
    """Dictionary shape returned by ``set_asset_info_tags``."""

    # Presumably the tags attached to reach the desired set -- TODO confirm.
    added: list[str]
    # Presumably the tags detached to reach the desired set -- TODO confirm.
    removed: list[str]
    # NOTE(review): this key is ``total``, whereas AddTagsDict and
    # RemoveTagsDict use ``total_tags``; callers must not mix them up.
    total: list[str]
MAX_BIND_PARAMS = 800
@ -60,7 +78,7 @@ def set_asset_info_tags(
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
) -> SetTagsDict:
desired = normalize_tags(tags)
current = set(
@ -96,8 +114,8 @@ def add_tags_to_asset_info(
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row = None,
) -> dict:
asset_info_row: AssetInfo | None = None,
) -> AddTagsDict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
@ -153,7 +171,7 @@ def remove_tags_from_asset_info(
session: Session,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
) -> RemoveTagsDict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")

View File

@ -268,7 +268,7 @@ def upload_asset_from_temp_path(
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
info_id = result.asset_info_id
if not info_id:
raise RuntimeError("failed to create asset metadata")
@ -290,7 +290,7 @@ def upload_asset_from_temp_path(
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
created_new=result.asset_created,
)
@ -479,13 +479,17 @@ def add_tags_to_asset(
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
data = apply_tags(
result = apply_tags(
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
owner_id=owner_id,
)
return schemas_out.TagsAdd(**data)
return schemas_out.TagsAdd(
added=result.added,
already_present=result.already_present,
total_tags=result.total_tags,
)
def remove_tags_from_asset(
@ -493,12 +497,16 @@ def remove_tags_from_asset(
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
data = remove_tags(
result = remove_tags(
asset_info_id=asset_info_id,
tags=tags,
owner_id=owner_id,
)
return schemas_out.TagsRemove(**data)
return schemas_out.TagsRemove(
removed=result.removed,
not_present=result.not_present,
total_tags=result.total_tags,
)
def list_tags(

View File

@ -8,6 +8,18 @@ from app.assets.services.ingest import (
ingest_file_from_path,
register_existing_asset,
)
from app.assets.services.schemas import (
AddTagsResult,
AssetData,
AssetDetailResult,
AssetInfoData,
IngestResult,
RegisterAssetResult,
RemoveTagsResult,
SetTagsResult,
TagUsage,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
@ -24,4 +36,14 @@ __all__ = [
"apply_tags",
"remove_tags",
"list_tags",
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetInfoData",
"IngestResult",
"RegisterAssetResult",
"RemoveTagsResult",
"SetTagsResult",
"TagUsage",
"UserMetadata",
]

View File

@ -2,6 +2,8 @@ import contextlib
import os
from typing import Sequence
from sqlalchemy.orm import Session
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_info_exists_for_asset_id,
@ -19,6 +21,7 @@ from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_relative_filename
from app.assets.services.schemas import (
AssetDetailResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
@ -29,10 +32,6 @@ def get_asset_detail(
asset_info_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
"""
Fetch full asset details including tags.
Returns AssetDetailResult or None if not found.
"""
with create_session() as session:
result = fetch_asset_info_asset_and_tags(
session,
@ -54,14 +53,10 @@ def update_asset_metadata(
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetDetailResult:
"""
Update name, tags, and/or metadata on an AssetInfo.
Returns AssetDetailResult with updated data.
"""
with create_session() as session:
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info:
@ -128,11 +123,6 @@ def delete_asset_reference(
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
"""
Delete an AssetInfo reference.
If delete_content_if_orphan is True and no other AssetInfos reference the asset,
also delete the Asset and its cached files.
"""
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
@ -175,10 +165,6 @@ def set_asset_preview(
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
"""
Set or clear preview_id on an AssetInfo.
Returns AssetDetailResult with updated data.
"""
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
@ -209,7 +195,6 @@ def set_asset_preview(
return detail
def _compute_filename_for_asset(session, asset_id: str) -> str | None:
"""Compute the relative filename for an asset from its cache states."""
def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None:
primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id))
return compute_relative_filename(primary_path) if primary_path else None

View File

@ -3,8 +3,9 @@ import os
from typing import Sequence
from sqlalchemy import select
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, Tag
from app.assets.database.models import Asset, AssetInfo, Tag
from app.assets.database.queries import (
add_tags_to_asset_info,
get_asset_by_hash,
@ -21,7 +22,9 @@ from app.assets.database.queries import (
from app.assets.helpers import normalize_tags, select_best_live_path
from app.assets.services.path_utils import compute_relative_filename
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UserMetadata,
extract_asset_data,
extract_info_data,
)
@ -37,41 +40,30 @@ def ingest_file_from_path(
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
) -> IngestResult:
locator = os.path.abspath(abs_path)
out: dict = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
asset_created = False
asset_updated = False
state_created = False
state_updated = False
asset_info_id: str | None = None
with create_session() as session:
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
asset, created, updated = upsert_asset(
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
out["asset_created"] = created
out["asset_updated"] = updated
state_created, state_updated = upsert_cache_state(
session,
@ -79,8 +71,6 @@ def ingest_file_from_path(
file_path=locator,
mtime_ns=mtime_ns,
)
out["state_created"] = state_created
out["state_updated"] = state_updated
if info_name:
info, info_created = get_or_create_asset_info(
@ -91,27 +81,27 @@ def ingest_file_from_path(
preview_id=preview_id,
)
if info_created:
out["asset_info_id"] = info.id
asset_info_id = info.id
else:
update_asset_info_timestamps(session, asset_info=info, preview_id=preview_id)
out["asset_info_id"] = info.id
asset_info_id = info.id
norm = normalize_tags(list(tags))
if norm and out["asset_info_id"]:
if norm and asset_info_id:
if require_existing_tags:
_validate_tags_exist(session, norm)
add_tags_to_asset_info(
session,
asset_info_id=out["asset_info_id"],
asset_info_id=asset_info_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
if out["asset_info_id"]:
if asset_info_id:
_update_metadata_with_filename(
session,
asset_info_id=out["asset_info_id"],
asset_info_id=asset_info_id,
asset_id=asset.id,
info=info,
user_metadata=user_metadata,
@ -124,22 +114,23 @@ def ingest_file_from_path(
session.commit()
return out
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
state_created=state_created,
state_updated=state_updated,
asset_info_id=asset_info_id,
)
def register_existing_asset(
asset_hash: str,
name: str,
user_metadata: dict | None = None,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
"""
Create or return existing AssetInfo for an asset that already exists by hash.
Returns RegisterAssetResult with plain data.
Raises ValueError if hash not found.
"""
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
@ -197,8 +188,7 @@ def register_existing_asset(
return result
def _validate_tags_exist(session, tags: list[str]) -> None:
"""Raise ValueError if any tags don't exist."""
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
existing_tag_names = set(
name for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
@ -207,20 +197,18 @@ def _validate_tags_exist(session, tags: list[str]) -> None:
raise ValueError(f"Unknown tags: {missing}")
def _compute_filename_for_asset(session, asset_id: str) -> str | None:
"""Compute the relative filename for an asset from its cache states."""
def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None:
primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id))
return compute_relative_filename(primary_path) if primary_path else None
def _update_metadata_with_filename(
session,
session: Session,
asset_info_id: str,
asset_id: str,
info,
user_metadata: dict | None,
info: AssetInfo,
user_metadata: UserMetadata,
) -> None:
"""Update metadata projection with computed filename."""
computed_filename = _compute_filename_for_asset(session, asset_id)
current_meta = info.user_metadata or {}

View File

@ -1,10 +1,14 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetInfo
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
"""Plain data extracted from an Asset ORM object."""
hash: str
size_bytes: int | None
mime_type: str | None
@ -12,10 +16,9 @@ class AssetData:
@dataclass(frozen=True)
class AssetInfoData:
"""Plain data extracted from an AssetInfo ORM object."""
id: str
name: str
user_metadata: dict | None
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
@ -24,7 +27,6 @@ class AssetInfoData:
@dataclass(frozen=True)
class AssetDetailResult:
"""Result from get_asset_detail and similar operations."""
info: AssetInfoData
asset: AssetData | None
tags: list[str]
@ -32,15 +34,49 @@ class AssetDetailResult:
@dataclass(frozen=True)
class RegisterAssetResult:
"""Result from register_existing_asset."""
info: AssetInfoData
asset: AssetData
tags: list[str]
created: bool
def extract_info_data(info) -> AssetInfoData:
"""Extract plain data from an AssetInfo ORM object."""
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
state_created: bool
state_updated: bool
asset_info_id: str | None
@dataclass(frozen=True)
class AddTagsResult:
    """Immutable outcome of ``apply_tags``, built from the query-layer dict."""

    # Requested tags that the call newly attached.
    added: list[str]
    # Requested tags that were already attached before the call.
    already_present: list[str]
    # Tags on the asset info after the call.
    total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
    """Immutable outcome of ``remove_tags``, built from the query-layer dict."""

    # Requested tags that the call detached.
    removed: list[str]
    # Requested tags that were not attached in the first place.
    not_present: list[str]
    # Tags left on the asset info after the call.
    total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
    """Immutable service-layer counterpart of the query-layer ``SetTagsDict``."""

    # Presumably the tags attached to reach the desired set -- TODO confirm.
    added: list[str]
    # Presumably the tags detached to reach the desired set -- TODO confirm.
    removed: list[str]
    # NOTE(review): named ``total`` here, while AddTagsResult and
    # RemoveTagsResult expose ``total_tags``.
    total: list[str]
class TagUsage(NamedTuple):
    """One row of the ``list_tags`` result: ``(name, tag_type, count)``."""

    name: str
    tag_type: str
    # NOTE: this field shadows the inherited ``tuple.count`` method, so
    # ``row.count`` is the usage number and is not callable.
    count: int
def extract_info_data(info: AssetInfo) -> AssetInfoData:
return AssetInfoData(
id=info.id,
name=info.name,
@ -52,8 +88,7 @@ def extract_info_data(info) -> AssetInfoData:
)
def extract_asset_data(asset) -> AssetData | None:
"""Extract plain data from an Asset ORM object."""
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(

View File

@ -4,6 +4,7 @@ from app.assets.database.queries import (
list_tags_with_usage,
remove_tags_from_asset_info,
)
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
from app.database.db import create_session
@ -12,11 +13,7 @@ def apply_tags(
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> dict:
"""
Add tags to an asset.
Returns dict with added, already_present, and total_tags lists.
"""
) -> AddTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
@ -34,18 +31,18 @@ def apply_tags(
)
session.commit()
return data
return AddTagsResult(
added=data["added"],
already_present=data["already_present"],
total_tags=data["total_tags"],
)
def remove_tags(
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> dict:
"""
Remove tags from an asset.
Returns dict with removed, not_present, and total_tags lists.
"""
) -> RemoveTagsResult:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
@ -60,7 +57,11 @@ def remove_tags(
)
session.commit()
return data
return RemoveTagsResult(
removed=data["removed"],
not_present=data["not_present"],
total_tags=data["total_tags"],
)
def list_tags(
@ -70,11 +71,7 @@ def list_tags(
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
"""
List tags with usage counts.
Returns (rows, total) where rows are (name, tag_type, count) tuples.
"""
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
@ -89,4 +86,4 @@ def list_tags(
owner_id=owner_id,
)
return rows, total
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total

View File

@ -22,9 +22,9 @@ class TestIngestFileFromPath:
mime_type="application/octet-stream",
)
assert result["asset_created"] is True
assert result["state_created"] is True
assert result["asset_info_id"] is None # no info_name provided
assert result.asset_created is True
assert result.state_created is True
assert result.asset_info_id is None # no info_name provided
# Verify DB state
assets = session.query(Asset).all()
@ -49,8 +49,8 @@ class TestIngestFileFromPath:
owner_id="user1",
)
assert result["asset_created"] is True
assert result["asset_info_id"] is not None
assert result.asset_created is True
assert result.asset_info_id is not None
info = session.query(AssetInfo).first()
assert info is not None
@ -70,7 +70,7 @@ class TestIngestFileFromPath:
tags=["models", "checkpoints"],
)
assert result["asset_info_id"] is not None
assert result.asset_info_id is not None
# Verify tags were created and linked
tags = session.query(Tag).all()
@ -78,7 +78,7 @@ class TestIngestFileFromPath:
assert "models" in tag_names
assert "checkpoints" in tag_names
asset_tags = get_asset_tags(session, asset_info_id=result["asset_info_id"])
asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id)
assert set(asset_tags) == {"models", "checkpoints"}
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
@ -92,7 +92,7 @@ class TestIngestFileFromPath:
size_bytes=7,
mtime_ns=1234567890000000000,
)
assert r1["asset_created"] is True
assert r1.asset_created is True
# Second ingest with same hash - should update, not create
r2 = ingest_file_from_path(
@ -101,8 +101,8 @@ class TestIngestFileFromPath:
size_bytes=7,
mtime_ns=1234567890000000001, # different mtime
)
assert r2["asset_created"] is False
assert r2["state_updated"] is True or r2["state_created"] is False
assert r2.asset_created is False
assert r2.state_updated is True or r2.state_created is False
# Still only one asset
assets = session.query(Asset).all()
@ -127,8 +127,8 @@ class TestIngestFileFromPath:
preview_id=preview_id,
)
assert result["asset_info_id"] is not None
info = session.query(AssetInfo).filter_by(id=result["asset_info_id"]).first()
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id == preview_id
def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session):
@ -144,8 +144,8 @@ class TestIngestFileFromPath:
preview_id="nonexistent-uuid",
)
assert result["asset_info_id"] is not None
info = session.query(AssetInfo).filter_by(id=result["asset_info_id"]).first()
assert result.asset_info_id is not None
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
assert info.preview_id is None

View File

@ -46,9 +46,9 @@ class TestApplyTags:
tags=["alpha", "beta"],
)
assert set(result["added"]) == {"alpha", "beta"}
assert result["already_present"] == []
assert set(result["total_tags"]) == {"alpha", "beta"}
assert set(result.added) == {"alpha", "beta"}
assert result.already_present == []
assert set(result.total_tags) == {"alpha", "beta"}
def test_reports_already_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
@ -62,8 +62,8 @@ class TestApplyTags:
tags=["existing", "new"],
)
assert result["added"] == ["new"]
assert result["already_present"] == ["existing"]
assert result.added == ["new"]
assert result.already_present == ["existing"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):
@ -95,9 +95,9 @@ class TestRemoveTags:
tags=["a", "b"],
)
assert set(result["removed"]) == {"a", "b"}
assert result["not_present"] == []
assert result["total_tags"] == ["c"]
assert set(result.removed) == {"a", "b"}
assert result.not_present == []
assert result.total_tags == ["c"]
def test_reports_not_present(self, mock_create_session, session: Session):
asset = _make_asset(session)
@ -111,8 +111,8 @@ class TestRemoveTags:
tags=["present", "absent"],
)
assert result["removed"] == ["present"]
assert result["not_present"] == ["absent"]
assert result.removed == ["present"]
assert result.not_present == ["absent"]
def test_raises_for_nonexistent_info(self, mock_create_session):
with pytest.raises(ValueError, match="not found"):