feat(assets): align local API with cloud spec

Unify response models, add missing fields, and align input schemas with
the cloud OpenAPI spec at cloud.comfy.org/openapi.

- Replace AssetSummary/AssetDetail/AssetUpdated with single Asset model
- Add is_immutable, metadata (system_metadata), prompt_id fields
- Support mime_type and preview_id in update endpoint
- Make CreateFromHashBody.name optional, add mime_type, require >=1 tag
- Add id/mime_type/preview_id to upload, relax tags to optional
- Rename total_tags → tags in tag add/remove responses
- Add GET /api/assets/tags/refine histogram endpoint
- Add DB migration for system_metadata and prompt_id columns

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
This commit is contained in:
Luke Mino-Altherr 2026-03-09 22:36:00 -07:00
parent 593be209a4
commit 07d17edbfd
15 changed files with 426 additions and 210 deletions

View File

@ -0,0 +1,31 @@
"""
Add system_metadata and prompt_id columns to asset_references.
Revision ID: 0003_add_metadata_prompt
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
revision = "0003_add_metadata_prompt"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("prompt_id", sa.String(length=36), nullable=True)
)
def downgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("prompt_id")
batch_op.drop_column("system_metadata")

View File

@ -38,6 +38,7 @@ from app.assets.services import (
update_asset_metadata, update_asset_metadata,
upload_from_temp_path, upload_from_temp_path,
) )
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None USER_MANAGER: user_manager.UserManager | None = None
@ -122,6 +123,29 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at" return "created_at"
def _build_asset_response(result) -> schemas_out.Asset:
    """Convert a service-layer result (ref + optional asset + tags) into an Asset model.

    `result.asset` may be None (reference without resolved content); in that
    case hash/mime_type are None and size falls back to 0.
    """
    ref = result.ref
    asset = result.asset
    # A preview URL exists only when the reference points at a preview asset.
    preview_url = (
        f"/api/assets/{ref.preview_id}/content?disposition=inline"
        if ref.preview_id
        else None
    )
    return schemas_out.Asset(
        id=ref.id,
        name=ref.name,
        asset_hash=asset.hash if asset else None,
        size=int(asset.size_bytes) if asset else 0,
        mime_type=asset.mime_type if asset else None,
        tags=result.tags,
        preview_url=preview_url,
        preview_id=ref.preview_id,
        user_metadata=ref.user_metadata or {},
        metadata=ref.system_metadata,
        prompt_id=ref.prompt_id,
        created_at=ref.created_at,
        updated_at=ref.updated_at,
        last_access_time=ref.last_access_time,
    )
@ROUTES.head("/api/assets/hash/{hash}") @ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled @_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response: async def head_asset_by_hash(request: web.Request) -> web.Response:
@ -164,20 +188,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order, order=order,
) )
summaries = [ summaries = [_build_asset_response(item) for item in result.items]
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList( payload = schemas_out.AssetsList(
assets=summaries, assets=summaries,
@ -207,18 +218,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id}, {"id": reference_id},
) )
payload = schemas_out.AssetDetail( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e: except ValueError as e:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@ -312,29 +312,27 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON." 400, "INVALID_JSON", "Request body must be valid JSON."
) )
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash( result = create_from_hash(
hash_str=body.hash, hash_str=body.hash,
name=body.name, name=name,
tags=body.tags, tags=body.tags,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
) )
if result is None: if result is None:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
) )
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated( payload_out = schemas_out.AssetCreated(
id=result.ref.id, **asset.model_dump(),
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
return web.json_response(payload_out.model_dump(mode="json"), status=201) return web.json_response(payload_out.model_dump(mode="json"), status=201)
@ -358,6 +356,9 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name, "name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw, "user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash, "hash": parsed.provided_hash,
"id": parsed.provided_id,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
} }
) )
except ValidationError as ve: except ValidationError as ve:
@ -378,6 +379,21 @@ async def upload_asset(request: web.Request) -> web.Response:
) )
try: try:
# Idempotent create: if spec.id is provided, check if reference already exists
if spec.id:
existing = get_asset_detail(
reference_id=spec.id,
owner_id=owner_id,
)
if existing:
delete_temp_file_if_exists(parsed.tmp_path)
asset = _build_asset_response(existing)
payload_out = schemas_out.AssetCreated(
**asset.model_dump(),
created_new=False,
)
return web.json_response(payload_out.model_dump(mode="json"), status=200)
# Fast path: hash exists, create AssetReference without writing anything # Fast path: hash exists, create AssetReference without writing anything
if spec.hash and parsed.provided_hash_exists is True: if spec.hash and parsed.provided_hash_exists is True:
result = create_from_hash( result = create_from_hash(
@ -386,6 +402,7 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags, tags=spec.tags,
user_metadata=spec.user_metadata or {}, user_metadata=spec.user_metadata or {},
owner_id=owner_id, owner_id=owner_id,
mime_type=spec.mime_type,
) )
if result is None: if result is None:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -410,6 +427,9 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name, client_filename=parsed.file_client_name,
owner_id=owner_id, owner_id=owner_id,
expected_hash=spec.hash, expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
asset_id=spec.id,
) )
except AssetValidationError as e: except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -428,21 +448,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id) logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated( asset = _build_asset_response(result)
id=result.ref.id, payload_out = schemas_out.AssetCreated(
name=result.ref.name, **asset.model_dump(),
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
status = 201 if result.created_new else 200 status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status) return web.json_response(payload_out.model_dump(mode="json"), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@ -464,15 +476,10 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name, name=body.name,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
) )
payload = schemas_out.AssetUpdated( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except PermissionError as pe: except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve: except ValueError as ve:
@ -587,7 +594,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsAdd( payload = schemas_out.TagsAdd(
added=result.added, added=result.added,
already_present=result.already_present, already_present=result.already_present,
total_tags=result.total_tags, tags=result.total_tags,
) )
except PermissionError as pe: except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@ -634,7 +641,7 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsRemove( payload = schemas_out.TagsRemove(
removed=result.removed, removed=result.removed,
not_present=result.not_present, not_present=result.not_present,
total_tags=result.total_tags, tags=result.total_tags,
) )
except PermissionError as pe: except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
@ -653,6 +660,28 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json"), status=200)
@ROUTES.post("/api/assets/seed") @ROUTES.post("/api/assets/seed")
@_require_assets_feature_enabled @_require_assets_feature_enabled
async def seed_assets(request: web.Request) -> web.Response: async def seed_assets(request: web.Request) -> web.Response:

View File

@ -45,6 +45,9 @@ class ParsedUpload:
user_metadata_raw: str | None user_metadata_raw: str | None
provided_hash: str | None provided_hash: str | None
provided_hash_exists: bool | None provided_hash_exists: bool | None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
@ -98,11 +101,18 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel): class UpdateAssetBody(BaseModel):
name: str | None = None name: str | None = None
user_metadata: dict[str, Any] | None = None user_metadata: dict[str, Any] | None = None
mime_type: str | None = None
preview_id: str | None = None
@model_validator(mode="after") @model_validator(mode="after")
def _validate_at_least_one_field(self): def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None: if all(
raise ValueError("Provide at least one of: name, user_metadata.") v is None
for v in (self.name, self.user_metadata, self.mime_type, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, mime_type, preview_id."
)
return self return self
@ -110,9 +120,10 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str hash: str
name: str name: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list, min_length=1)
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
@field_validator("hash") @field_validator("hash")
@classmethod @classmethod
@ -138,6 +149,44 @@ class CreateFromHashBody(BaseModel):
return [] return []
class TagsRefineQuery(BaseModel):
    """Query parameters for GET /api/assets/tags/refine.

    - include_tags / exclude_tags: CSV strings or repeated params; every
      include tag must be present, no exclude tag may be present.
    - name_contains: substring filter on the asset name.
    - metadata_filter: JSON-encoded object filtering on metadata keys.
    - limit: maximum number of histogram entries (1..1000, default 100).
    """

    # Consistency fix: the sibling query/body models in this module
    # (e.g. TagsListQuery, CreateFromHashBody) all ignore unknown fields
    # and strip surrounding whitespace; this model previously did not.
    model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)

    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
    name_contains: str | None = None
    metadata_filter: dict[str, Any] | None = None
    limit: conint(ge=1, le=1000) = 100

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def _split_csv_tags(cls, v):
        """Accept None, a CSV string, or a list of (possibly CSV) strings.

        Non-string list items are silently dropped — presumably only string
        query params are expected here; confirm against get_query_dict.
        """
        if v is None:
            return []
        if isinstance(v, str):
            return [t.strip() for t in v.split(",") if t.strip()]
        if isinstance(v, list):
            out: list[str] = []
            for item in v:
                if isinstance(item, str):
                    out.extend([t.strip() for t in item.split(",") if t.strip()])
            return out
        return v

    @field_validator("metadata_filter", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        """Parse a JSON-object string; pass dicts through; reject non-objects.

        Blank/whitespace-only strings normalize to None (no filter).
        """
        if v is None or isinstance(v, dict):
            return v
        if isinstance(v, str) and v.strip():
            try:
                parsed = json.loads(v)
            except Exception as e:
                raise ValueError(f"metadata_filter must be JSON: {e}") from e
            if not isinstance(parsed, dict):
                raise ValueError("metadata_filter must be a JSON object")
            return parsed
        return None
class TagsListQuery(BaseModel): class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@ -186,21 +235,27 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel): class UploadAssetSpec(BaseModel):
"""Upload Asset operation. """Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output'); - tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category if root == 'models', second must be a valid category
- name: display name - name: display name
- user_metadata: arbitrary JSON object (optional) - user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path - hash: optional canonical 'blake3:<hex>' for validation / fast-path
- id: optional UUID for idempotent creation
- mime_type: optional MIME type override
- preview_id: optional asset ID for preview
Files are stored using the content hash as filename stem. Files are stored using the content hash as filename stem.
""" """
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1) tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name") name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None) hash: str | None = Field(default=None)
id: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None)
@field_validator("hash", mode="before") @field_validator("hash", mode="before")
@classmethod @classmethod
@ -278,14 +333,13 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_order(self): def _validate_order(self):
if not self.tags: if self.tags:
raise ValueError("tags must be provided and non-empty") root = self.tags[0]
root = self.tags[0] if root not in {"models", "input", "output"}:
if root not in {"models", "input", "output"}: raise ValueError("first tag must be one of: models, input, output")
raise ValueError("first tag must be one of: models, input, output") if root == "models":
if root == "models": if len(self.tags) < 2:
if len(self.tags) < 2: raise ValueError(
raise ValueError( "models uploads require a category tag as the second tag"
"models uploads require a category tag as the second tag" )
)
return self return self

View File

@ -4,16 +4,21 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel): class Asset(BaseModel):
id: str id: str
name: str name: str
asset_hash: str | None = None asset_hash: str | None = None
size: int | None = None size: int = 0
mime_type: str | None = None mime_type: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None
created_at: datetime | None = None preview_id: str | None = None
updated_at: datetime | None = None user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -23,50 +28,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel): class AssetsList(BaseModel):
assets: list[AssetSummary] assets: list[Asset]
total: int total: int
has_more: bool has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel): class TagUsage(BaseModel):
name: str name: str
count: int count: int
@ -83,11 +54,15 @@ class TagsAdd(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True) model_config = ConfigDict(str_strip_whitespace=True)
added: list[str] = Field(default_factory=list) added: list[str] = Field(default_factory=list)
already_present: list[str] = Field(default_factory=list) already_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
class TagsRemove(BaseModel): class TagsRemove(BaseModel):
model_config = ConfigDict(str_strip_whitespace=True) model_config = ConfigDict(str_strip_whitespace=True)
removed: list[str] = Field(default_factory=list) removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
    """Response body for GET /api/assets/tags/refine."""
    # Mapping of tag name -> number of matching assets carrying that tag.
    tag_counts: dict[str, int]

View File

@ -52,6 +52,9 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None user_metadata_raw: str | None = None
provided_hash: str | None = None provided_hash: str | None = None
provided_hash_exists: bool | None = None provided_hash_exists: bool | None = None
provided_id: str | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0 file_written = 0
tmp_path: str | None = None tmp_path: str | None = None
@ -128,6 +131,12 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None provided_name = (await field.text()) or None
elif fname == "user_metadata": elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None user_metadata_raw = (await field.text()) or None
elif fname == "id":
provided_id = ((await field.text()) or "").strip() or None
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists): if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError( raise UploadError(
@ -152,6 +161,9 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw, user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash, provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists, provided_hash_exists=provided_hash_exists,
provided_id=provided_id,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
) )

View File

@ -96,6 +96,10 @@ class AssetReference(Base):
user_metadata: Mapped[dict[str, Any] | None] = mapped_column( user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True) JSON(none_as_null=True)
) )
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
prompt_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now DateTime(timezone=False), nullable=False, default=get_utc_now
) )

View File

@ -54,6 +54,7 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta, bulk_insert_tags_and_meta,
ensure_tags_exist, ensure_tags_exist,
get_reference_tags, get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage, list_tags_with_usage,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
remove_tags_from_reference, remove_tags_from_reference,
@ -99,6 +100,7 @@ __all__ = [
"insert_reference", "insert_reference",
"list_references_by_asset_id", "list_references_by_asset_id",
"list_references_page", "list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage", "list_tags_with_usage",
"mark_references_missing_outside_prefixes", "mark_references_missing_outside_prefixes",
"reassign_asset_references", "reassign_asset_references",

View File

@ -24,6 +24,8 @@ from app.assets.database.models import (
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
MAX_BIND_PARAMS, MAX_BIND_PARAMS,
apply_metadata_filter,
apply_tag_filters,
build_prefix_like_conditions, build_prefix_like_conditions,
build_visible_owner_clause, build_visible_owner_clause,
calculate_rows_per_statement, calculate_rows_per_statement,
@ -79,83 +81,6 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]:
return [{"key": key, "ordinal": 0, "val_json": value}] return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetReferenceMeta.val_json.is_(None),
AssetReferenceMeta.val_str.is_(None),
AssetReferenceMeta.val_num.is_(None),
AssetReferenceMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def get_reference_by_id( def get_reference_by_id(
@ -336,8 +261,8 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains) escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower() sort = (sort or "created_at").lower()
order = (order or "desc").lower() order = (order or "desc").lower()
@ -366,8 +291,8 @@ def list_references_page(
count_stmt = count_stmt.where( count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc) AssetReference.name.ilike(f"%{escaped}%", escape=esc)
) )
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0) total = int(session.execute(count_stmt).scalar_one() or 0)
refs = session.execute(base).unique().scalars().all() refs = session.execute(base).unique().scalars().all()

View File

@ -1,12 +1,14 @@
"""Shared utilities for database query modules.""" """Shared utilities for database query modules."""
import os import os
from typing import Iterable from decimal import Decimal
from typing import Iterable, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800 MAX_BIND_PARAMS = 800
@ -52,3 +54,82 @@ def build_prefix_like_conditions(
escaped, esc = escape_sql_like_string(base) escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds return conds
def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Narrow *stmt* by tags: require every include tag, forbid all exclude tags.

    Each include tag contributes its own EXISTS subquery (AND semantics);
    the exclude tags share a single NOT EXISTS with an IN clause.
    """
    wanted = normalize_tags(include_tags)
    unwanted = normalize_tags(exclude_tags)

    # One EXISTS per required tag: the reference must carry all of them.
    # (Iterating an empty list is a no-op, so no guard is needed.)
    for tag in wanted:
        has_tag = exists().where(
            (AssetReferenceTag.asset_reference_id == AssetReference.id)
            & (AssetReferenceTag.tag_name == tag)
        )
        stmt = stmt.where(has_tag)

    # A single NOT EXISTS rejects references carrying any forbidden tag.
    if unwanted:
        has_any_unwanted = exists().where(
            (AssetReferenceTag.asset_reference_id == AssetReference.id)
            & (AssetReferenceTag.tag_name.in_(unwanted))
        )
        stmt = stmt.where(~has_any_unwanted)

    return stmt
def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Apply filters using asset_reference_meta projection table.

    Each key/value pair becomes an EXISTS predicate against the
    AssetReferenceMeta rows, matched by value type: bool -> val_bool,
    int/float/Decimal -> val_num, str -> val_str, anything else -> val_json.
    A None value matches references that either have no row for the key or
    have a row whose typed columns are all NULL.  A list value means
    "match any element" (OR of the per-element clauses).
    """
    if not metadata_filter:
        return stmt
    def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
        # EXISTS over projection rows for this reference/key, plus extra predicates.
        return sa.exists().where(
            AssetReferenceMeta.asset_reference_id == AssetReference.id,
            AssetReferenceMeta.key == key,
            *preds,
        )
    def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            # "value is null" matches both an absent key and an explicit all-NULL row.
            no_row_for_key = sa.not_(
                sa.exists().where(
                    AssetReferenceMeta.asset_reference_id == AssetReference.id,
                    AssetReferenceMeta.key == key,
                )
            )
            null_row = _exists_for_pred(
                key,
                AssetReferenceMeta.val_json.is_(None),
                AssetReferenceMeta.val_str.is_(None),
                AssetReferenceMeta.val_num.is_(None),
                AssetReferenceMeta.val_bool.is_(None),
            )
            return sa.or_(no_row_for_key, null_row)
        # bool is checked before the numeric branch: bool is a subclass of int
        # in Python and would otherwise be swallowed by the val_num case.
        if isinstance(value, bool):
            return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
        if isinstance(value, (int, float, Decimal)):
            # Route floats through str() so Decimal gets the printed value,
            # not the exact binary expansion of the float.
            num = value if isinstance(value, Decimal) else Decimal(str(value))
            return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
        if isinstance(value, str):
            return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
        # Fallback: compare against the raw JSON column.
        return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
    for k, v in metadata_filter.items():
        if isinstance(v, list):
            # List => OR across elements ("value is any of these").
            ors = [_exists_clause_for_value(k, elem) for elem in v]
            if ors:
                stmt = stmt.where(sa.or_(*ors))
        else:
            stmt = stmt.where(_exists_clause_for_value(k, v))
    return stmt

View File

@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.database.models import ( from app.assets.database.models import (
Asset,
AssetReference, AssetReference,
AssetReferenceMeta, AssetReferenceMeta,
AssetReferenceTag, AssetReferenceTag,
Tag, Tag,
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause, build_visible_owner_clause,
iter_row_chunks, iter_row_chunks,
) )
@ -320,6 +323,53 @@ def list_tags_with_usage(
return rows_norm, int(total or 0) return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
    session: Session,
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 100,
) -> dict[str, int]:
    """Return ``{tag_name: count}`` for assets matching the given filters.

    Applies the same visibility, name, tag, and metadata filtering as
    list_references_page, but aggregates tag usage over the matching
    references instead of returning a page of rows.  Results are ordered
    by descending count and capped at *limit* entries.
    """
    # Subquery of reference IDs that survive every filter.
    matching_refs = (
        select(AssetReference.id)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(build_visible_owner_clause(owner_id))
        .where(AssetReference.is_missing == False)  # noqa: E712
        .where(AssetReference.deleted_at.is_(None))
    )
    if name_contains:
        pattern, esc_char = escape_sql_like_string(name_contains)
        matching_refs = matching_refs.where(
            AssetReference.name.ilike(f"%{pattern}%", escape=esc_char)
        )
    matching_refs = apply_tag_filters(matching_refs, include_tags, exclude_tags)
    matching_refs = apply_metadata_filter(matching_refs, metadata_filter).subquery()

    # Aggregate tag usage across the surviving references.
    histogram_q = (
        select(
            AssetReferenceTag.tag_name,
            func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
        )
        .where(AssetReferenceTag.asset_reference_id.in_(select(matching_refs.c.id)))
        .group_by(AssetReferenceTag.tag_name)
        .order_by(func.count(AssetReferenceTag.asset_reference_id).desc())
        .limit(limit)
    )
    return {
        tag_name: int(count)
        for tag_name, count in session.execute(histogram_q).all()
    }
def bulk_insert_tags_and_meta( def bulk_insert_tags_and_meta(
session: Session, session: Session,
tag_rows: list[dict], tag_rows: list[dict],

View File

@ -20,6 +20,7 @@ from app.assets.database.queries import (
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_reference_updated_at, update_reference_updated_at,
@ -67,6 +68,8 @@ def update_asset_metadata(
user_metadata: UserMetadata = None, user_metadata: UserMetadata = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id) ref = get_reference_with_owner_check(session, reference_id, owner_id)
@ -103,6 +106,20 @@ def update_asset_metadata(
) )
touched = True touched = True
if mime_type is not None:
update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_asset_id=preview_id,
)
touched = True
if touched and user_metadata is None: if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id) update_reference_updated_at(session, reference_id=reference_id)

View File

@ -18,6 +18,7 @@ from app.assets.database.queries import (
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
upsert_asset, upsert_asset,
upsert_reference, upsert_reference,
validate_tags_exist, validate_tags_exist,
@ -242,6 +243,9 @@ def upload_from_temp_path(
client_filename: str | None = None, client_filename: str | None = None,
owner_id: str = "", owner_id: str = "",
expected_hash: str | None = None, expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
asset_id: str | None = None,
) -> UploadResult: ) -> UploadResult:
try: try:
digest, _ = hashing.compute_blake3_hash(temp_path) digest, _ = hashing.compute_blake3_hash(temp_path)
@ -291,7 +295,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir) validate_path_within_base(dest_abs, base_dir)
content_type = ( content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0] or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream" or "application/octet-stream"
@ -315,7 +319,7 @@ def upload_from_temp_path(
mime_type=content_type, mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest), info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id, owner_id=owner_id,
preview_id=None, preview_id=preview_id,
user_metadata=user_metadata or {}, user_metadata=user_metadata or {},
tags=tags, tags=tags,
tag_origin="manual", tag_origin="manual",
@ -348,6 +352,7 @@ def create_from_hash(
tags: list[str] | None = None, tags: list[str] | None = None,
user_metadata: dict | None = None, user_metadata: dict | None = None,
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult | None: ) -> UploadResult | None:
canonical = hash_str.strip().lower() canonical = hash_str.strip().lower()
@ -356,6 +361,10 @@ def create_from_hash(
if not asset: if not asset:
return None return None
if mime_type and asset.mime_type != mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
session.commit()
result = _register_existing_asset( result = _register_existing_asset(
asset_hash=canonical, asset_hash=canonical,
name=_sanitize_filename( name=_sanitize_filename(

View File

@ -23,9 +23,11 @@ class ReferenceData:
file_path: str | None file_path: str | None
user_metadata: UserMetadata user_metadata: UserMetadata
preview_id: str | None preview_id: str | None
created_at: datetime system_metadata: dict[str, Any] | None = None
updated_at: datetime prompt_id: str | None = None
last_access_time: datetime | None created_at: datetime = None # type: ignore[assignment]
updated_at: datetime = None # type: ignore[assignment]
last_access_time: datetime | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path, file_path=ref.file_path,
user_metadata=ref.user_metadata, user_metadata=ref.user_metadata,
preview_id=ref.preview_id, preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
prompt_id=ref.prompt_id,
created_at=ref.created_at, created_at=ref.created_at,
updated_at=ref.updated_at, updated_at=ref.updated_at,
last_access_time=ref.last_access_time, last_access_time=ref.last_access_time,

View File

@ -1,3 +1,5 @@
from typing import Sequence
from app.assets.database.queries import ( from app.assets.database.queries import (
AddTagsResult, AddTagsResult,
RemoveTagsResult, RemoveTagsResult,
@ -6,6 +8,7 @@ from app.assets.database.queries import (
list_tags_with_usage, list_tags_with_usage,
remove_tags_from_reference, remove_tags_from_reference,
) )
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage from app.assets.services.schemas import TagUsage
from app.database.db import create_session from app.database.db import create_session
@ -73,3 +76,23 @@ def list_tags(
) )
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 100,
) -> dict[str, int]:
    """Return ``{tag_name: count}`` for assets matching the given filters.

    Service-layer wrapper: opens a session and delegates to
    list_tag_counts_for_filtered_assets with the same arguments.
    """
    filters = dict(
        owner_id=owner_id,
        include_tags=include_tags,
        exclude_tags=exclude_tags,
        name_contains=name_contains,
        metadata_filter=metadata_filter,
        limit=limit,
    )
    with create_session() as session:
        return list_tag_counts_for_filtered_assets(session, **filters)

View File

@ -97,7 +97,7 @@ def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset
# normalized, deduplicated; 'unit-tests' was already present from the seed # normalized, deduplicated; 'unit-tests' was already present from the seed
assert set(b1["added"]) == {"newtag", "beta"} assert set(b1["added"]) == {"newtag", "beta"}
assert set(b1["already_present"]) == {"unit-tests"} assert set(b1["already_present"]) == {"unit-tests"}
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] assert "newtag" in b1["tags"] and "beta" in b1["tags"]
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
g = rg.json() g = rg.json()