From 9290e26e9f79f37417bebf7bdf04b21ac038a0b3 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Tue, 3 Feb 2026 20:32:14 -0800 Subject: [PATCH] refactor: add explicit types to asset service functions - Add typed result dataclasses: IngestResult, AddTagsResult, RemoveTagsResult, SetTagsResult, TagUsage - Add UserMetadata type alias for user_metadata parameters - Type helper functions with Session parameters - Use TypedDicts at query layer to avoid circular imports - Update manager.py and tests to use attribute access Co-Authored-By: Claude Opus 4.5 --- app/assets/database/queries/__init__.py | 6 ++ app/assets/database/queries/tags.py | 28 +++++-- app/assets/manager.py | 20 +++-- app/assets/services/__init__.py | 22 ++++++ app/assets/services/asset_management.py | 25 ++----- app/assets/services/ingest.py | 74 ++++++++----------- app/assets/services/schemas.py | 53 ++++++++++--- app/assets/services/tagging.py | 33 ++++----- .../assets_test/services/test_ingest.py | 28 +++---- .../assets_test/services/test_tagging.py | 20 ++--- 10 files changed, 184 insertions(+), 125 deletions(-) diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index a9fecb378..a24f82f3c 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -37,6 +37,9 @@ from app.assets.database.queries.cache_state import ( upsert_cache_state, ) from app.assets.database.queries.tags import ( + AddTagsDict, + RemoveTagsDict, + SetTagsDict, add_missing_tag_for_asset_id, add_tags_to_asset_info, bulk_insert_tags_and_meta, @@ -90,4 +93,7 @@ __all__ = [ "remove_missing_tag_for_asset_id", "list_tags_with_usage", "bulk_insert_tags_and_meta", + "AddTagsDict", + "RemoveTagsDict", + "SetTagsDict", ] diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 497e74870..7733d6e2b 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -1,4 +1,4 @@ -from typing import Iterable, Sequence +from typing import Iterable, Sequence, TypedDict import sqlalchemy as sa from sqlalchemy import delete, func, select @@ -9,6 +9,24 @@ from sqlalchemy.orm import Session from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags + +class AddTagsDict(TypedDict): + added: list[str] + already_present: list[str] + total_tags: list[str] + + +class RemoveTagsDict(TypedDict): + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +class SetTagsDict(TypedDict): + added: list[str] + removed: list[str] + total: list[str] + MAX_BIND_PARAMS = 800 @@ -60,7 +78,7 @@ def set_asset_info_tags( asset_info_id: str, tags: Sequence[str], origin: str = "manual", -) -> dict: +) -> SetTagsDict: desired = normalize_tags(tags) current = set( @@ -96,8 +114,8 @@ def add_tags_to_asset_info( tags: Sequence[str], origin: str = "manual", create_if_missing: bool = True, - asset_info_row = None, -) -> dict: + asset_info_row: AssetInfo | None = None, +) -> AddTagsDict: if not asset_info_row: info = session.get(AssetInfo, asset_info_id) if not info: @@ -153,7 +171,7 @@ def remove_tags_from_asset_info( session: Session, asset_info_id: str, tags: Sequence[str], -) -> dict: +) -> RemoveTagsDict: info = session.get(AssetInfo, asset_info_id) if not info: raise ValueError(f"AssetInfo {asset_info_id} not found") diff --git a/app/assets/manager.py b/app/assets/manager.py index ab1b77955..2da444352 100644 --- a/app/assets/manager.py +++ b/app/assets/manager.py @@ -268,7 +268,7 @@ def upload_asset_from_temp_path( tag_origin="manual", require_existing_tags=False, ) - info_id = result["asset_info_id"] + info_id = result.asset_info_id if not info_id: raise RuntimeError("failed to create asset metadata") @@ -290,7 +290,7 @@ def upload_asset_from_temp_path( preview_id=info.preview_id, created_at=info.created_at, last_access_time=info.last_access_time, - created_new=result["asset_created"], + created_new=result.asset_created, ) @@ -479,13 +479,17 @@ def add_tags_to_asset( origin: str = "manual", owner_id: str = "", ) -> schemas_out.TagsAdd: - data = apply_tags( + result = apply_tags( asset_info_id=asset_info_id, tags=tags, origin=origin, owner_id=owner_id, ) - return schemas_out.TagsAdd(**data) + return schemas_out.TagsAdd( + added=result.added, + already_present=result.already_present, + total_tags=result.total_tags, + ) def remove_tags_from_asset( @@ -493,12 +497,16 @@ def remove_tags_from_asset( tags: list[str], owner_id: str = "", ) -> schemas_out.TagsRemove: - data = remove_tags( + result = remove_tags( asset_info_id=asset_info_id, tags=tags, owner_id=owner_id, ) - return schemas_out.TagsRemove(**data) + return schemas_out.TagsRemove( + removed=result.removed, + not_present=result.not_present, + total_tags=result.total_tags, + ) def list_tags( diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py index 5ce0ae0d1..7e4758a5f 100644 --- a/app/assets/services/__init__.py +++ b/app/assets/services/__init__.py @@ -8,6 +8,18 @@ from app.assets.services.ingest import ( ingest_file_from_path, register_existing_asset, ) +from app.assets.services.schemas import ( + AddTagsResult, + AssetData, + AssetDetailResult, + AssetInfoData, + IngestResult, + RegisterAssetResult, + RemoveTagsResult, + SetTagsResult, + TagUsage, + UserMetadata, +) from app.assets.services.tagging import ( apply_tags, list_tags, @@ -24,4 +36,14 @@ __all__ = [ "apply_tags", "remove_tags", "list_tags", + "AddTagsResult", + "AssetData", + "AssetDetailResult", + "AssetInfoData", + "IngestResult", + "RegisterAssetResult", + "RemoveTagsResult", + "SetTagsResult", + "TagUsage", + "UserMetadata", ] diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 4f23cf9af..e42cc728f 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -2,6 +2,8 @@ import contextlib import os from typing import Sequence +from sqlalchemy.orm import Session + from app.assets.database.models import Asset from app.assets.database.queries import ( asset_info_exists_for_asset_id, @@ -19,6 +21,7 @@ from app.assets.helpers import select_best_live_path from app.assets.services.path_utils import compute_relative_filename from app.assets.services.schemas import ( AssetDetailResult, + UserMetadata, extract_asset_data, extract_info_data, ) @@ -29,10 +32,6 @@ def get_asset_detail( asset_info_id: str, owner_id: str = "", ) -> AssetDetailResult | None: - """ - Fetch full asset details including tags. - Returns AssetDetailResult or None if not found. - """ with create_session() as session: result = fetch_asset_info_asset_and_tags( session, @@ -54,14 +53,10 @@ def update_asset_metadata( asset_info_id: str, name: str | None = None, tags: Sequence[str] | None = None, - user_metadata: dict | None = None, + user_metadata: UserMetadata = None, tag_origin: str = "manual", owner_id: str = "", ) -> AssetDetailResult: - """ - Update name, tags, and/or metadata on an AssetInfo. - Returns AssetDetailResult with updated data. - """ with create_session() as session: info = get_asset_info_by_id(session, asset_info_id=asset_info_id) if not info: @@ -128,11 +123,6 @@ def delete_asset_reference( owner_id: str, delete_content_if_orphan: bool = True, ) -> bool: - """ - Delete an AssetInfo reference. - If delete_content_if_orphan is True and no other AssetInfos reference the asset, - also delete the Asset and its cached files. - """ with create_session() as session: info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) asset_id = info_row.asset_id if info_row else None @@ -175,10 +165,6 @@ def set_asset_preview( preview_asset_id: str | None = None, owner_id: str = "", ) -> AssetDetailResult: - """ - Set or clear preview_id on an AssetInfo. - Returns AssetDetailResult with updated data. - """ with create_session() as session: info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) if not info_row: @@ -209,7 +195,6 @@ def set_asset_preview( return detail -def _compute_filename_for_asset(session, asset_id: str) -> str | None: - """Compute the relative filename for an asset from its cache states.""" +def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None: primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id)) return compute_relative_filename(primary_path) if primary_path else None diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index cf88adee6..745f27704 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -3,8 +3,9 @@ import os from typing import Sequence from sqlalchemy import select +from sqlalchemy.orm import Session -from app.assets.database.models import Asset, Tag +from app.assets.database.models import Asset, AssetInfo, Tag from app.assets.database.queries import ( add_tags_to_asset_info, get_asset_by_hash, @@ -21,7 +22,9 @@ from app.assets.database.queries import ( from app.assets.helpers import normalize_tags, select_best_live_path from app.assets.services.path_utils import compute_relative_filename from app.assets.services.schemas import ( + IngestResult, RegisterAssetResult, + UserMetadata, extract_asset_data, extract_info_data, ) @@ -37,41 +40,30 @@ def ingest_file_from_path( info_name: str | None = None, owner_id: str = "", preview_id: str | None = None, - user_metadata: dict | None = None, + user_metadata: UserMetadata = None, tags: Sequence[str] = (), tag_origin: str = "manual", require_existing_tags: bool = False, -) -> dict: - """ - Idempotently upsert: - - Asset by content hash (create if missing) - - AssetCacheState(file_path) pointing to asset_id - - Optionally AssetInfo + tag links and metadata projection - Returns flags and ids. - """ +) -> IngestResult: locator = os.path.abspath(abs_path) - out: dict = { - "asset_created": False, - "asset_updated": False, - "state_created": False, - "state_updated": False, - "asset_info_id": None, - } + asset_created = False + asset_updated = False + state_created = False + state_updated = False + asset_info_id: str | None = None with create_session() as session: if preview_id: if not session.get(Asset, preview_id): preview_id = None - asset, created, updated = upsert_asset( + asset, asset_created, asset_updated = upsert_asset( session, asset_hash=asset_hash, size_bytes=size_bytes, mime_type=mime_type, ) - out["asset_created"] = created - out["asset_updated"] = updated state_created, state_updated = upsert_cache_state( session, @@ -79,8 +71,6 @@ def ingest_file_from_path( file_path=locator, mtime_ns=mtime_ns, ) - out["state_created"] = state_created - out["state_updated"] = state_updated if info_name: info, info_created = get_or_create_asset_info( @@ -91,27 +81,27 @@ def ingest_file_from_path( preview_id=preview_id, ) if info_created: - out["asset_info_id"] = info.id + asset_info_id = info.id else: update_asset_info_timestamps(session, asset_info=info, preview_id=preview_id) - out["asset_info_id"] = info.id + asset_info_id = info.id norm = normalize_tags(list(tags)) - if norm and out["asset_info_id"]: + if norm and asset_info_id: if require_existing_tags: _validate_tags_exist(session, norm) add_tags_to_asset_info( session, - asset_info_id=out["asset_info_id"], + asset_info_id=asset_info_id, tags=norm, origin=tag_origin, create_if_missing=not require_existing_tags, ) - if out["asset_info_id"]: + if asset_info_id: _update_metadata_with_filename( session, - asset_info_id=out["asset_info_id"], + asset_info_id=asset_info_id, asset_id=asset.id, info=info, user_metadata=user_metadata, @@ -124,22 +114,23 @@ def ingest_file_from_path( session.commit() - return out + return IngestResult( + asset_created=asset_created, + asset_updated=asset_updated, + state_created=state_created, + state_updated=state_updated, + asset_info_id=asset_info_id, + ) def register_existing_asset( asset_hash: str, name: str, - user_metadata: dict | None = None, + user_metadata: UserMetadata = None, tags: list[str] | None = None, tag_origin: str = "manual", owner_id: str = "", ) -> RegisterAssetResult: - """ - Create or return existing AssetInfo for an asset that already exists by hash. - Returns RegisterAssetResult with plain data. - Raises ValueError if hash not found. - """ with create_session() as session: asset = get_asset_by_hash(session, asset_hash=asset_hash) if not asset: @@ -197,8 +188,7 @@ def register_existing_asset( return result -def _validate_tags_exist(session, tags: list[str]) -> None: - """Raise ValueError if any tags don't exist.""" +def _validate_tags_exist(session: Session, tags: list[str]) -> None: existing_tag_names = set( name for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() ) @@ -207,20 +197,18 @@ def _validate_tags_exist(session, tags: list[str]) -> None: raise ValueError(f"Unknown tags: {missing}") -def _compute_filename_for_asset(session, asset_id: str) -> str | None: - """Compute the relative filename for an asset from its cache states.""" +def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None: primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id)) return compute_relative_filename(primary_path) if primary_path else None def _update_metadata_with_filename( - session, + session: Session, asset_info_id: str, asset_id: str, - info, - user_metadata: dict | None, + info: AssetInfo, + user_metadata: UserMetadata, ) -> None: - """Update metadata projection with computed filename.""" computed_filename = _compute_filename_for_asset(session, asset_id) current_meta = info.user_metadata or {} diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index a0fd02553..8727f5732 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -1,10 +1,14 @@ from dataclasses import dataclass from datetime import datetime +from typing import Any, NamedTuple + +from app.assets.database.models import Asset, AssetInfo + +UserMetadata = dict[str, Any] | None @dataclass(frozen=True) class AssetData: - """Plain data extracted from an Asset ORM object.""" hash: str size_bytes: int | None mime_type: str | None @@ -12,10 +16,9 @@ class AssetData: @dataclass(frozen=True) class AssetInfoData: - """Plain data extracted from an AssetInfo ORM object.""" id: str name: str - user_metadata: dict | None + user_metadata: UserMetadata preview_id: str | None created_at: datetime updated_at: datetime @@ -24,7 +27,6 @@ class AssetInfoData: @dataclass(frozen=True) class AssetDetailResult: - """Result from get_asset_detail and similar operations.""" info: AssetInfoData asset: AssetData | None tags: list[str] @@ -32,15 +34,49 @@ class AssetDetailResult: @dataclass(frozen=True) class RegisterAssetResult: - """Result from register_existing_asset.""" info: AssetInfoData asset: AssetData tags: list[str] created: bool -def extract_info_data(info) -> AssetInfoData: - """Extract plain data from an AssetInfo ORM object.""" +@dataclass(frozen=True) +class IngestResult: + asset_created: bool + asset_updated: bool + state_created: bool + state_updated: bool + asset_info_id: str | None + + +@dataclass(frozen=True) +class AddTagsResult: + added: list[str] + already_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class RemoveTagsResult: + removed: list[str] + not_present: list[str] + total_tags: list[str] + + +@dataclass(frozen=True) +class SetTagsResult: + added: list[str] + removed: list[str] + total: list[str] + + +class TagUsage(NamedTuple): + name: str + tag_type: str + count: int + + +def extract_info_data(info: AssetInfo) -> AssetInfoData: return AssetInfoData( id=info.id, name=info.name, @@ -52,8 +88,7 @@ def extract_info_data(info) -> AssetInfoData: ) -def extract_asset_data(asset) -> AssetData | None: - """Extract plain data from an Asset ORM object.""" +def extract_asset_data(asset: Asset | None) -> AssetData | None: if asset is None: return None return AssetData( diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py index b9f5a8c69..d46b3599e 100644 --- a/app/assets/services/tagging.py +++ b/app/assets/services/tagging.py @@ -4,6 +4,7 @@ from app.assets.database.queries import ( list_tags_with_usage, remove_tags_from_asset_info, ) +from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage from app.database.db import create_session @@ -12,11 +13,7 @@ def apply_tags( tags: list[str], origin: str = "manual", owner_id: str = "", -) -> dict: - """ - Add tags to an asset. - Returns dict with added, already_present, and total_tags lists. - """ +) -> AddTagsResult: with create_session() as session: info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) if not info_row: @@ -34,18 +31,18 @@ def apply_tags( ) session.commit() - return data + return AddTagsResult( + added=data["added"], + already_present=data["already_present"], + total_tags=data["total_tags"], + ) def remove_tags( asset_info_id: str, tags: list[str], owner_id: str = "", -) -> dict: - """ - Remove tags from an asset. - Returns dict with removed, not_present, and total_tags lists. - """ +) -> RemoveTagsResult: with create_session() as session: info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) if not info_row: @@ -60,7 +57,11 @@ def remove_tags( ) session.commit() - return data + return RemoveTagsResult( + removed=data["removed"], + not_present=data["not_present"], + total_tags=data["total_tags"], + ) def list_tags( @@ -70,11 +71,7 @@ def list_tags( order: str = "count_desc", include_zero: bool = True, owner_id: str = "", -) -> tuple[list[tuple[str, str, int]], int]: - """ - List tags with usage counts. - Returns (rows, total) where rows are (name, tag_type, count) tuples. - """ +) -> tuple[list[TagUsage], int]: limit = max(1, min(1000, limit)) offset = max(0, offset) @@ -89,4 +86,4 @@ def list_tags( owner_id=owner_id, ) - return rows, total + return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py index d1817ff3a..3c47f7207 100644 --- a/tests-unit/assets_test/services/test_ingest.py +++ b/tests-unit/assets_test/services/test_ingest.py @@ -22,9 +22,9 @@ class TestIngestFileFromPath: mime_type="application/octet-stream", ) - assert result["asset_created"] is True - assert result["state_created"] is True - assert result["asset_info_id"] is None # no info_name provided + assert result.asset_created is True + assert result.state_created is True + assert result.asset_info_id is None # no info_name provided # Verify DB state assets = session.query(Asset).all() @@ -49,8 +49,8 @@ class TestIngestFileFromPath: owner_id="user1", ) - assert result["asset_created"] is True - assert result["asset_info_id"] is not None + assert result.asset_created is True + assert result.asset_info_id is not None info = session.query(AssetInfo).first() assert info is not None @@ -70,7 +70,7 @@ class TestIngestFileFromPath: tags=["models", "checkpoints"], ) - assert result["asset_info_id"] is not None + assert result.asset_info_id is not None # Verify tags were created and linked tags = session.query(Tag).all() @@ -78,7 +78,7 @@ class TestIngestFileFromPath: assert "models" in tag_names assert "checkpoints" in tag_names - asset_tags = get_asset_tags(session, asset_info_id=result["asset_info_id"]) + asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id) assert set(asset_tags) == {"models", "checkpoints"} def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session): @@ -92,7 +92,7 @@ class TestIngestFileFromPath: size_bytes=7, mtime_ns=1234567890000000000, ) - assert r1["asset_created"] is True + assert r1.asset_created is True # Second ingest with same hash - should update, not create r2 = ingest_file_from_path( @@ -101,8 +101,8 @@ class TestIngestFileFromPath: size_bytes=7, mtime_ns=1234567890000000001, # different mtime ) - assert r2["asset_created"] is False - assert r2["state_updated"] is True or r2["state_created"] is False + assert r2.asset_created is False + assert r2.state_updated is True or r2.state_created is False # Still only one asset assets = session.query(Asset).all() @@ -127,8 +127,8 @@ class TestIngestFileFromPath: preview_id=preview_id, ) - assert result["asset_info_id"] is not None - info = session.query(AssetInfo).filter_by(id=result["asset_info_id"]).first() + assert result.asset_info_id is not None + info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first() assert info.preview_id == preview_id def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session): @@ -144,8 +144,8 @@ class TestIngestFileFromPath: preview_id="nonexistent-uuid", ) - assert result["asset_info_id"] is not None - info = session.query(AssetInfo).filter_by(id=result["asset_info_id"]).first() + assert result.asset_info_id is not None + info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first() assert info.preview_id is None diff --git a/tests-unit/assets_test/services/test_tagging.py b/tests-unit/assets_test/services/test_tagging.py index d9e7b2a5b..b3be52244 100644 --- a/tests-unit/assets_test/services/test_tagging.py +++ b/tests-unit/assets_test/services/test_tagging.py @@ -46,9 +46,9 @@ class TestApplyTags: tags=["alpha", "beta"], ) - assert set(result["added"]) == {"alpha", "beta"} - assert result["already_present"] == [] - assert set(result["total_tags"]) == {"alpha", "beta"} + assert set(result.added) == {"alpha", "beta"} + assert result.already_present == [] + assert set(result.total_tags) == {"alpha", "beta"} def test_reports_already_present(self, mock_create_session, session: Session): asset = _make_asset(session) @@ -62,8 +62,8 @@ class TestApplyTags: tags=["existing", "new"], ) - assert result["added"] == ["new"] - assert result["already_present"] == ["existing"] + assert result.added == ["new"] + assert result.already_present == ["existing"] def test_raises_for_nonexistent_info(self, mock_create_session): with pytest.raises(ValueError, match="not found"): @@ -95,9 +95,9 @@ class TestRemoveTags: tags=["a", "b"], ) - assert set(result["removed"]) == {"a", "b"} - assert result["not_present"] == [] - assert result["total_tags"] == ["c"] + assert set(result.removed) == {"a", "b"} + assert result.not_present == [] + assert result.total_tags == ["c"] def test_reports_not_present(self, mock_create_session, session: Session): asset = _make_asset(session) @@ -111,8 +111,8 @@ class TestRemoveTags: tags=["present", "absent"], ) - assert result["removed"] == ["present"] - assert result["not_present"] == ["absent"] + assert result.removed == ["present"] + assert result.not_present == ["absent"] def test_raises_for_nonexistent_info(self, mock_create_session): with pytest.raises(ValueError, match="not found"):