mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 11:32:31 +08:00
refactor: add explicit types to asset service functions
- Add typed result dataclasses: IngestResult, AddTagsResult, RemoveTagsResult, SetTagsResult, TagUsage - Add UserMetadata type alias for user_metadata parameters - Type helper functions with Session parameters - Use TypedDicts at query layer to avoid circular imports - Update manager.py and tests to use attribute access Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
This commit is contained in:
parent
37ecc5b663
commit
9290e26e9f
@ -37,6 +37,9 @@ from app.assets.database.queries.cache_state import (
|
||||
upsert_cache_state,
|
||||
)
|
||||
from app.assets.database.queries.tags import (
|
||||
AddTagsDict,
|
||||
RemoveTagsDict,
|
||||
SetTagsDict,
|
||||
add_missing_tag_for_asset_id,
|
||||
add_tags_to_asset_info,
|
||||
bulk_insert_tags_and_meta,
|
||||
@ -90,4 +93,7 @@ __all__ = [
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"list_tags_with_usage",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"AddTagsDict",
|
||||
"RemoveTagsDict",
|
||||
"SetTagsDict",
|
||||
]
|
||||
|
||||
@ -1,4 +1,4 @@
|
||||
from typing import Iterable, Sequence
|
||||
from typing import Iterable, Sequence, TypedDict
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
@ -9,6 +9,24 @@ from sqlalchemy.orm import Session
|
||||
from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
class AddTagsDict(TypedDict):
|
||||
added: list[str]
|
||||
already_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
class RemoveTagsDict(TypedDict):
|
||||
removed: list[str]
|
||||
not_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
class SetTagsDict(TypedDict):
|
||||
added: list[str]
|
||||
removed: list[str]
|
||||
total: list[str]
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
@ -60,7 +78,7 @@ def set_asset_info_tags(
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
) -> SetTagsDict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
@ -96,8 +114,8 @@ def add_tags_to_asset_info(
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row = None,
|
||||
) -> dict:
|
||||
asset_info_row: AssetInfo | None = None,
|
||||
) -> AddTagsDict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
@ -153,7 +171,7 @@ def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
) -> RemoveTagsDict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
@ -268,7 +268,7 @@ def upload_asset_from_temp_path(
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
info_id = result.asset_info_id
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
@ -290,7 +290,7 @@ def upload_asset_from_temp_path(
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
created_new=result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
@ -479,13 +479,17 @@ def add_tags_to_asset(
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
data = apply_tags(
|
||||
result = apply_tags(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
return schemas_out.TagsAdd(**data)
|
||||
return schemas_out.TagsAdd(
|
||||
added=result.added,
|
||||
already_present=result.already_present,
|
||||
total_tags=result.total_tags,
|
||||
)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
@ -493,12 +497,16 @@ def remove_tags_from_asset(
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
data = remove_tags(
|
||||
result = remove_tags(
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
return schemas_out.TagsRemove(**data)
|
||||
return schemas_out.TagsRemove(
|
||||
removed=result.removed,
|
||||
not_present=result.not_present,
|
||||
total_tags=result.total_tags,
|
||||
)
|
||||
|
||||
|
||||
def list_tags(
|
||||
|
||||
@ -8,6 +8,18 @@ from app.assets.services.ingest import (
|
||||
ingest_file_from_path,
|
||||
register_existing_asset,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
AddTagsResult,
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetInfoData,
|
||||
IngestResult,
|
||||
RegisterAssetResult,
|
||||
RemoveTagsResult,
|
||||
SetTagsResult,
|
||||
TagUsage,
|
||||
UserMetadata,
|
||||
)
|
||||
from app.assets.services.tagging import (
|
||||
apply_tags,
|
||||
list_tags,
|
||||
@ -24,4 +36,14 @@ __all__ = [
|
||||
"apply_tags",
|
||||
"remove_tags",
|
||||
"list_tags",
|
||||
"AddTagsResult",
|
||||
"AssetData",
|
||||
"AssetDetailResult",
|
||||
"AssetInfoData",
|
||||
"IngestResult",
|
||||
"RegisterAssetResult",
|
||||
"RemoveTagsResult",
|
||||
"SetTagsResult",
|
||||
"TagUsage",
|
||||
"UserMetadata",
|
||||
]
|
||||
|
||||
@ -2,6 +2,8 @@ import contextlib
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
asset_info_exists_for_asset_id,
|
||||
@ -19,6 +21,7 @@ from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.schemas import (
|
||||
AssetDetailResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_info_data,
|
||||
)
|
||||
@ -29,10 +32,6 @@ def get_asset_detail(
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult | None:
|
||||
"""
|
||||
Fetch full asset details including tags.
|
||||
Returns AssetDetailResult or None if not found.
|
||||
"""
|
||||
with create_session() as session:
|
||||
result = fetch_asset_info_asset_and_tags(
|
||||
session,
|
||||
@ -54,14 +53,10 @@ def update_asset_metadata(
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
"""
|
||||
Update name, tags, and/or metadata on an AssetInfo.
|
||||
Returns AssetDetailResult with updated data.
|
||||
"""
|
||||
with create_session() as session:
|
||||
info = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info:
|
||||
@ -128,11 +123,6 @@ def delete_asset_reference(
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete an AssetInfo reference.
|
||||
If delete_content_if_orphan is True and no other AssetInfos reference the asset,
|
||||
also delete the Asset and its cached files.
|
||||
"""
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
@ -175,10 +165,6 @@ def set_asset_preview(
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
"""
|
||||
Set or clear preview_id on an AssetInfo.
|
||||
Returns AssetDetailResult with updated data.
|
||||
"""
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
@ -209,7 +195,6 @@ def set_asset_preview(
|
||||
return detail
|
||||
|
||||
|
||||
def _compute_filename_for_asset(session, asset_id: str) -> str | None:
|
||||
"""Compute the relative filename for an asset from its cache states."""
|
||||
def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None:
|
||||
primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id))
|
||||
return compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
@ -3,8 +3,9 @@ import os
|
||||
from typing import Sequence
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, Tag
|
||||
from app.assets.database.models import Asset, AssetInfo, Tag
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_asset_info,
|
||||
get_asset_by_hash,
|
||||
@ -21,7 +22,9 @@ from app.assets.database.queries import (
|
||||
from app.assets.helpers import normalize_tags, select_best_live_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.schemas import (
|
||||
IngestResult,
|
||||
RegisterAssetResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_info_data,
|
||||
)
|
||||
@ -37,41 +40,30 @@ def ingest_file_from_path(
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
) -> IngestResult:
|
||||
locator = os.path.abspath(abs_path)
|
||||
|
||||
out: dict = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
asset_created = False
|
||||
asset_updated = False
|
||||
state_created = False
|
||||
state_updated = False
|
||||
asset_info_id: str | None = None
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
asset, created, updated = upsert_asset(
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
out["asset_created"] = created
|
||||
out["asset_updated"] = updated
|
||||
|
||||
state_created, state_updated = upsert_cache_state(
|
||||
session,
|
||||
@ -79,8 +71,6 @@ def ingest_file_from_path(
|
||||
file_path=locator,
|
||||
mtime_ns=mtime_ns,
|
||||
)
|
||||
out["state_created"] = state_created
|
||||
out["state_updated"] = state_updated
|
||||
|
||||
if info_name:
|
||||
info, info_created = get_or_create_asset_info(
|
||||
@ -91,27 +81,27 @@ def ingest_file_from_path(
|
||||
preview_id=preview_id,
|
||||
)
|
||||
if info_created:
|
||||
out["asset_info_id"] = info.id
|
||||
asset_info_id = info.id
|
||||
else:
|
||||
update_asset_info_timestamps(session, asset_info=info, preview_id=preview_id)
|
||||
out["asset_info_id"] = info.id
|
||||
asset_info_id = info.id
|
||||
|
||||
norm = normalize_tags(list(tags))
|
||||
if norm and out["asset_info_id"]:
|
||||
if norm and asset_info_id:
|
||||
if require_existing_tags:
|
||||
_validate_tags_exist(session, norm)
|
||||
add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
asset_info_id=asset_info_id,
|
||||
tags=norm,
|
||||
origin=tag_origin,
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
|
||||
if out["asset_info_id"]:
|
||||
if asset_info_id:
|
||||
_update_metadata_with_filename(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
asset_info_id=asset_info_id,
|
||||
asset_id=asset.id,
|
||||
info=info,
|
||||
user_metadata=user_metadata,
|
||||
@ -124,22 +114,23 @@ def ingest_file_from_path(
|
||||
|
||||
session.commit()
|
||||
|
||||
return out
|
||||
return IngestResult(
|
||||
asset_created=asset_created,
|
||||
asset_updated=asset_updated,
|
||||
state_created=state_created,
|
||||
state_updated=state_updated,
|
||||
asset_info_id=asset_info_id,
|
||||
)
|
||||
|
||||
|
||||
def register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> RegisterAssetResult:
|
||||
"""
|
||||
Create or return existing AssetInfo for an asset that already exists by hash.
|
||||
Returns RegisterAssetResult with plain data.
|
||||
Raises ValueError if hash not found.
|
||||
"""
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
@ -197,8 +188,7 @@ def register_existing_asset(
|
||||
return result
|
||||
|
||||
|
||||
def _validate_tags_exist(session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any tags don't exist."""
|
||||
def _validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
existing_tag_names = set(
|
||||
name for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
|
||||
)
|
||||
@ -207,20 +197,18 @@ def _validate_tags_exist(session, tags: list[str]) -> None:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def _compute_filename_for_asset(session, asset_id: str) -> str | None:
|
||||
"""Compute the relative filename for an asset from its cache states."""
|
||||
def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None:
|
||||
primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id))
|
||||
return compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
|
||||
def _update_metadata_with_filename(
|
||||
session,
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
asset_id: str,
|
||||
info,
|
||||
user_metadata: dict | None,
|
||||
info: AssetInfo,
|
||||
user_metadata: UserMetadata,
|
||||
) -> None:
|
||||
"""Update metadata projection with computed filename."""
|
||||
computed_filename = _compute_filename_for_asset(session, asset_id)
|
||||
|
||||
current_meta = info.user_metadata or {}
|
||||
|
||||
@ -1,10 +1,14 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from app.assets.database.models import Asset, AssetInfo
|
||||
|
||||
UserMetadata = dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetData:
|
||||
"""Plain data extracted from an Asset ORM object."""
|
||||
hash: str
|
||||
size_bytes: int | None
|
||||
mime_type: str | None
|
||||
@ -12,10 +16,9 @@ class AssetData:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetInfoData:
|
||||
"""Plain data extracted from an AssetInfo ORM object."""
|
||||
id: str
|
||||
name: str
|
||||
user_metadata: dict | None
|
||||
user_metadata: UserMetadata
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
@ -24,7 +27,6 @@ class AssetInfoData:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetDetailResult:
|
||||
"""Result from get_asset_detail and similar operations."""
|
||||
info: AssetInfoData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
@ -32,15 +34,49 @@ class AssetDetailResult:
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterAssetResult:
|
||||
"""Result from register_existing_asset."""
|
||||
info: AssetInfoData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created: bool
|
||||
|
||||
|
||||
def extract_info_data(info) -> AssetInfoData:
|
||||
"""Extract plain data from an AssetInfo ORM object."""
|
||||
@dataclass(frozen=True)
|
||||
class IngestResult:
|
||||
asset_created: bool
|
||||
asset_updated: bool
|
||||
state_created: bool
|
||||
state_updated: bool
|
||||
asset_info_id: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AddTagsResult:
|
||||
added: list[str]
|
||||
already_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoveTagsResult:
|
||||
removed: list[str]
|
||||
not_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetTagsResult:
|
||||
added: list[str]
|
||||
removed: list[str]
|
||||
total: list[str]
|
||||
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
def extract_info_data(info: AssetInfo) -> AssetInfoData:
|
||||
return AssetInfoData(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
@ -52,8 +88,7 @@ def extract_info_data(info) -> AssetInfoData:
|
||||
)
|
||||
|
||||
|
||||
def extract_asset_data(asset) -> AssetData | None:
|
||||
"""Extract plain data from an Asset ORM object."""
|
||||
def extract_asset_data(asset: Asset | None) -> AssetData | None:
|
||||
if asset is None:
|
||||
return None
|
||||
return AssetData(
|
||||
|
||||
@ -4,6 +4,7 @@ from app.assets.database.queries import (
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_asset_info,
|
||||
)
|
||||
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
@ -12,11 +13,7 @@ def apply_tags(
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Add tags to an asset.
|
||||
Returns dict with added, already_present, and total_tags lists.
|
||||
"""
|
||||
) -> AddTagsResult:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
@ -34,18 +31,18 @@ def apply_tags(
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return data
|
||||
return AddTagsResult(
|
||||
added=data["added"],
|
||||
already_present=data["already_present"],
|
||||
total_tags=data["total_tags"],
|
||||
)
|
||||
|
||||
|
||||
def remove_tags(
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""
|
||||
Remove tags from an asset.
|
||||
Returns dict with removed, not_present, and total_tags lists.
|
||||
"""
|
||||
) -> RemoveTagsResult:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
@ -60,7 +57,11 @@ def remove_tags(
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return data
|
||||
return RemoveTagsResult(
|
||||
removed=data["removed"],
|
||||
not_present=data["not_present"],
|
||||
total_tags=data["total_tags"],
|
||||
)
|
||||
|
||||
|
||||
def list_tags(
|
||||
@ -70,11 +71,7 @@ def list_tags(
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
"""
|
||||
List tags with usage counts.
|
||||
Returns (rows, total) where rows are (name, tag_type, count) tuples.
|
||||
"""
|
||||
) -> tuple[list[TagUsage], int]:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
@ -89,4 +86,4 @@ def list_tags(
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return rows, total
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
|
||||
@ -22,9 +22,9 @@ class TestIngestFileFromPath:
|
||||
mime_type="application/octet-stream",
|
||||
)
|
||||
|
||||
assert result["asset_created"] is True
|
||||
assert result["state_created"] is True
|
||||
assert result["asset_info_id"] is None # no info_name provided
|
||||
assert result.asset_created is True
|
||||
assert result.state_created is True
|
||||
assert result.asset_info_id is None # no info_name provided
|
||||
|
||||
# Verify DB state
|
||||
assets = session.query(Asset).all()
|
||||
@ -49,8 +49,8 @@ class TestIngestFileFromPath:
|
||||
owner_id="user1",
|
||||
)
|
||||
|
||||
assert result["asset_created"] is True
|
||||
assert result["asset_info_id"] is not None
|
||||
assert result.asset_created is True
|
||||
assert result.asset_info_id is not None
|
||||
|
||||
info = session.query(AssetInfo).first()
|
||||
assert info is not None
|
||||
@ -70,7 +70,7 @@ class TestIngestFileFromPath:
|
||||
tags=["models", "checkpoints"],
|
||||
)
|
||||
|
||||
assert result["asset_info_id"] is not None
|
||||
assert result.asset_info_id is not None
|
||||
|
||||
# Verify tags were created and linked
|
||||
tags = session.query(Tag).all()
|
||||
@ -78,7 +78,7 @@ class TestIngestFileFromPath:
|
||||
assert "models" in tag_names
|
||||
assert "checkpoints" in tag_names
|
||||
|
||||
asset_tags = get_asset_tags(session, asset_info_id=result["asset_info_id"])
|
||||
asset_tags = get_asset_tags(session, asset_info_id=result.asset_info_id)
|
||||
assert set(asset_tags) == {"models", "checkpoints"}
|
||||
|
||||
def test_idempotent_upsert(self, mock_create_session, temp_dir: Path, session: Session):
|
||||
@ -92,7 +92,7 @@ class TestIngestFileFromPath:
|
||||
size_bytes=7,
|
||||
mtime_ns=1234567890000000000,
|
||||
)
|
||||
assert r1["asset_created"] is True
|
||||
assert r1.asset_created is True
|
||||
|
||||
# Second ingest with same hash - should update, not create
|
||||
r2 = ingest_file_from_path(
|
||||
@ -101,8 +101,8 @@ class TestIngestFileFromPath:
|
||||
size_bytes=7,
|
||||
mtime_ns=1234567890000000001, # different mtime
|
||||
)
|
||||
assert r2["asset_created"] is False
|
||||
assert r2["state_updated"] is True or r2["state_created"] is False
|
||||
assert r2.asset_created is False
|
||||
assert r2.state_updated is True or r2.state_created is False
|
||||
|
||||
# Still only one asset
|
||||
assets = session.query(Asset).all()
|
||||
@ -127,8 +127,8 @@ class TestIngestFileFromPath:
|
||||
preview_id=preview_id,
|
||||
)
|
||||
|
||||
assert result["asset_info_id"] is not None
|
||||
info = session.query(AssetInfo).filter_by(id=result["asset_info_id"]).first()
|
||||
assert result.asset_info_id is not None
|
||||
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
|
||||
assert info.preview_id == preview_id
|
||||
|
||||
def test_invalid_preview_id_is_cleared(self, mock_create_session, temp_dir: Path, session: Session):
|
||||
@ -144,8 +144,8 @@ class TestIngestFileFromPath:
|
||||
preview_id="nonexistent-uuid",
|
||||
)
|
||||
|
||||
assert result["asset_info_id"] is not None
|
||||
info = session.query(AssetInfo).filter_by(id=result["asset_info_id"]).first()
|
||||
assert result.asset_info_id is not None
|
||||
info = session.query(AssetInfo).filter_by(id=result.asset_info_id).first()
|
||||
assert info.preview_id is None
|
||||
|
||||
|
||||
|
||||
@ -46,9 +46,9 @@ class TestApplyTags:
|
||||
tags=["alpha", "beta"],
|
||||
)
|
||||
|
||||
assert set(result["added"]) == {"alpha", "beta"}
|
||||
assert result["already_present"] == []
|
||||
assert set(result["total_tags"]) == {"alpha", "beta"}
|
||||
assert set(result.added) == {"alpha", "beta"}
|
||||
assert result.already_present == []
|
||||
assert set(result.total_tags) == {"alpha", "beta"}
|
||||
|
||||
def test_reports_already_present(self, mock_create_session, session: Session):
|
||||
asset = _make_asset(session)
|
||||
@ -62,8 +62,8 @@ class TestApplyTags:
|
||||
tags=["existing", "new"],
|
||||
)
|
||||
|
||||
assert result["added"] == ["new"]
|
||||
assert result["already_present"] == ["existing"]
|
||||
assert result.added == ["new"]
|
||||
assert result.already_present == ["existing"]
|
||||
|
||||
def test_raises_for_nonexistent_info(self, mock_create_session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
@ -95,9 +95,9 @@ class TestRemoveTags:
|
||||
tags=["a", "b"],
|
||||
)
|
||||
|
||||
assert set(result["removed"]) == {"a", "b"}
|
||||
assert result["not_present"] == []
|
||||
assert result["total_tags"] == ["c"]
|
||||
assert set(result.removed) == {"a", "b"}
|
||||
assert result.not_present == []
|
||||
assert result.total_tags == ["c"]
|
||||
|
||||
def test_reports_not_present(self, mock_create_session, session: Session):
|
||||
asset = _make_asset(session)
|
||||
@ -111,8 +111,8 @@ class TestRemoveTags:
|
||||
tags=["present", "absent"],
|
||||
)
|
||||
|
||||
assert result["removed"] == ["present"]
|
||||
assert result["not_present"] == ["absent"]
|
||||
assert result.removed == ["present"]
|
||||
assert result.not_present == ["absent"]
|
||||
|
||||
def test_raises_for_nonexistent_info(self, mock_create_session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
|
||||
Loading…
Reference in New Issue
Block a user