This commit is contained in:
Luke Mino-Altherr 2026-03-12 17:39:15 -07:00 committed by GitHub
commit 6aa1ba5dff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 1183 additions and 318 deletions

View File

@ -8,7 +8,7 @@ from alembic import context
config = context.config config = context.config
from app.database.models import Base from app.database.models import Base, NAMING_CONVENTION
target_metadata = Base.metadata target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(
connection=connection, target_metadata=target_metadata connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
naming_convention=NAMING_CONVENTION,
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@ -0,0 +1,66 @@
"""
Add system_metadata and prompt_id columns to asset_references.
Change preview_id FK from assets.id to asset_references.id.
Revision ID: 0003_add_metadata_prompt
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
from app.database.models import NAMING_CONVENTION
revision = "0003_add_metadata_prompt"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("prompt_id", sa.String(length=36), nullable=True)
)
# Change preview_id FK from assets.id to asset_references.id (self-ref).
# Existing values are asset-content IDs that won't match reference IDs,
# so null them out first.
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_constraint(
"fk_asset_references_preview_id_assets", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_asset_references",
"asset_references",
["preview_id"],
["id"],
ondelete="SET NULL",
)
def downgrade() -> None:
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_constraint(
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_assets",
"assets",
["preview_id"],
["id"],
ondelete="SET NULL",
)
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("prompt_id")
batch_op.drop_column("system_metadata")

View File

@ -13,6 +13,7 @@ from pydantic import ValidationError
import folder_paths import folder_paths
from app import user_manager from app import user_manager
from app.assets.api import schemas_in, schemas_out from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas
from app.assets.api.schemas_in import ( from app.assets.api.schemas_in import (
AssetValidationError, AssetValidationError,
UploadError, UploadError,
@ -38,6 +39,7 @@ from app.assets.services import (
update_asset_metadata, update_asset_metadata,
upload_from_temp_path, upload_from_temp_path,
) )
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None USER_MANAGER: user_manager.UserManager | None = None
@ -122,6 +124,60 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at" return "created_at"
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
if not user_metadata:
return None
filename = user_metadata.get("filename")
if not filename:
return None
if "input" in tags:
view_type = "input"
elif "output" in tags:
view_type = "output"
else:
return None
subfolder = ""
if "/" in filename:
subfolder, filename = filename.rsplit("/", 1)
encoded_filename = urllib.parse.quote(filename, safe="")
url = f"/api/view?type={view_type}&filename={encoded_filename}"
if subfolder:
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
return url
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
"""Build an Asset response from a service result."""
if result.ref.preview_id:
preview_detail = get_asset_detail(result.ref.preview_id)
if preview_detail:
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
else:
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
preview_url=preview_url,
preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
prompt_id=result.ref.prompt_id,
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
)
@ROUTES.head("/api/assets/hash/{hash}") @ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled @_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response: async def head_asset_by_hash(request: web.Request) -> web.Response:
@ -164,20 +220,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order, order=order,
) )
summaries = [ summaries = [_build_asset_response(item) for item in result.items]
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList( payload = schemas_out.AssetsList(
assets=summaries, assets=summaries,
@ -207,18 +250,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id}, {"id": reference_id},
) )
payload = schemas_out.AssetDetail( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e: except ValueError as e:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@ -230,7 +262,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request), USER_MANAGER.get_request_user_id(request),
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@ -312,32 +344,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON." 400, "INVALID_JSON", "Request body must be valid JSON."
) )
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash( result = create_from_hash(
hash_str=body.hash, hash_str=body.hash,
name=body.name, name=name,
tags=body.tags, tags=body.tags,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
) )
if result is None: if result is None:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
) )
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated( payload_out = schemas_out.AssetCreated(
id=result.ref.id, **asset.model_dump(),
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
return web.json_response(payload_out.model_dump(mode="json"), status=201) return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
@ROUTES.post("/api/assets") @ROUTES.post("/api/assets")
@ -358,6 +389,8 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name, "name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw, "user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash, "hash": parsed.provided_hash,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
} }
) )
except ValidationError as ve: except ValidationError as ve:
@ -386,6 +419,8 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags, tags=spec.tags,
user_metadata=spec.user_metadata or {}, user_metadata=spec.user_metadata or {},
owner_id=owner_id, owner_id=owner_id,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
) )
if result is None: if result is None:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -410,6 +445,8 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name, client_filename=parsed.file_client_name,
owner_id=owner_id, owner_id=owner_id,
expected_hash=spec.hash, expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
) )
except AssetValidationError as e: except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -428,21 +465,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id) logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated( asset = _build_asset_response(result)
id=result.ref.id, payload_out = schemas_out.AssetCreated(
name=result.ref.name, **asset.model_dump(),
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
status = 201 if result.created_new else 200 status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status) return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@ -464,15 +493,10 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name, name=body.name,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
) )
payload = schemas_out.AssetUpdated( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except PermissionError as pe: except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve: except ValueError as ve:
@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request), USER_MANAGER.get_request_user_id(request),
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsList( payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total tags=tags, total=total, has_more=(query.offset + len(tags)) < total
) )
return web.json_response(payload.model_dump(mode="json")) return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.post("/api/assets/seed") @ROUTES.post("/api/assets/seed")

View File

@ -45,6 +45,8 @@ class ParsedUpload:
user_metadata_raw: str | None user_metadata_raw: str | None
provided_hash: str | None provided_hash: str | None
provided_hash_exists: bool | None provided_hash_exists: bool | None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
@ -98,11 +100,18 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel): class UpdateAssetBody(BaseModel):
name: str | None = None name: str | None = None
user_metadata: dict[str, Any] | None = None user_metadata: dict[str, Any] | None = None
mime_type: str | None = None
preview_id: str | None = None
@model_validator(mode="after") @model_validator(mode="after")
def _validate_at_least_one_field(self): def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None: if all(
raise ValueError("Provide at least one of: name, user_metadata.") v is None
for v in (self.name, self.user_metadata, self.mime_type, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, mime_type, preview_id."
)
return self return self
@ -110,9 +119,11 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str hash: str
name: str name: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
preview_id: str | None = None
@field_validator("hash") @field_validator("hash")
@classmethod @classmethod
@ -138,6 +149,44 @@ class CreateFromHashBody(BaseModel):
return [] return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel): class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@ -186,21 +235,25 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel): class UploadAssetSpec(BaseModel):
"""Upload Asset operation. """Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output'); - tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category if root == 'models', second must be a valid category
- name: display name - name: display name
- user_metadata: arbitrary JSON object (optional) - user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path - hash: optional canonical 'blake3:<hex>' for validation / fast-path
- mime_type: optional MIME type override
- preview_id: optional asset ID for preview
Files are stored using the content hash as filename stem. Files are stored using the content hash as filename stem.
""" """
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1) tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name") name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None) hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None)
@field_validator("hash", mode="before") @field_validator("hash", mode="before")
@classmethod @classmethod
@ -279,7 +332,7 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_order(self): def _validate_order(self):
if not self.tags: if not self.tags:
raise ValueError("tags must be provided and non-empty") raise ValueError("at least one tag is required for uploads")
root = self.tags[0] root = self.tags[0]
if root not in {"models", "input", "output"}: if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output") raise ValueError("first tag must be one of: models, input, output")

View File

@ -4,7 +4,7 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel): class Asset(BaseModel):
id: str id: str
name: str name: str
asset_hash: str | None = None asset_hash: str | None = None
@ -12,8 +12,13 @@ class AssetSummary(BaseModel):
mime_type: str | None = None mime_type: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None
created_at: datetime | None = None preview_id: str | None = None
updated_at: datetime | None = None user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
prompt_id: str | None = None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -23,50 +28,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel): class AssetsList(BaseModel):
assets: list[AssetSummary] assets: list[Asset]
total: int total: int
has_more: bool has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel): class TagUsage(BaseModel):
name: str name: str
count: int count: int
@ -91,3 +62,7 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list) removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

View File

@ -52,6 +52,8 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None user_metadata_raw: str | None = None
provided_hash: str | None = None provided_hash: str | None = None
provided_hash_exists: bool | None = None provided_hash_exists: bool | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0 file_written = 0
tmp_path: str | None = None tmp_path: str | None = None
@ -128,6 +130,16 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None provided_name = (await field.text()) or None
elif fname == "user_metadata": elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None user_metadata_raw = (await field.text()) or None
elif fname == "id":
raise UploadError(
400,
"UNSUPPORTED_FIELD",
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
)
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists): if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError( raise UploadError(
@ -152,6 +164,8 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw, user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash, provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists, provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
) )

View File

@ -45,13 +45,7 @@ class Asset(Base):
passive_deletes=True, passive_deletes=True,
) )
preview_of: Mapped[list[AssetReference]] = relationship( # preview_id on AssetReference is a self-referential FK to asset_references.id
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
__table_args__ = ( __table_args__ = (
Index("uq_assets_hash", "hash", unique=True), Index("uq_assets_hash", "hash", unique=True),
@ -91,11 +85,15 @@ class AssetReference(Base):
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False) name: Mapped[str] = mapped_column(String(512), nullable=False)
preview_id: Mapped[str | None] = mapped_column( preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL") String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
) )
user_metadata: Mapped[dict[str, Any] | None] = mapped_column( user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True) JSON(none_as_null=True)
) )
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
prompt_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now DateTime(timezone=False), nullable=False, default=get_utc_now
) )
@ -115,10 +113,10 @@ class AssetReference(Base):
foreign_keys=[asset_id], foreign_keys=[asset_id],
lazy="selectin", lazy="selectin",
) )
preview_asset: Mapped[Asset | None] = relationship( preview_ref: Mapped[AssetReference | None] = relationship(
"Asset", "AssetReference",
back_populates="preview_of",
foreign_keys=[preview_id], foreign_keys=[preview_id],
remote_side=lambda: [AssetReference.id],
) )
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(

View File

@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import (
get_unenriched_references, get_unenriched_references,
get_unreferenced_unhashed_asset_ids, get_unreferenced_unhashed_asset_ids,
insert_reference, insert_reference,
list_all_file_paths_by_asset_id,
list_references_by_asset_id, list_references_by_asset_id,
list_references_page, list_references_page,
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
rebuild_metadata_projection,
reference_exists,
reference_exists_for_asset_id, reference_exists_for_asset_id,
restore_references_by_paths, restore_references_by_paths,
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_system_metadata,
soft_delete_reference_by_id, soft_delete_reference_by_id,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_is_missing_by_asset_id,
update_reference_timestamps, update_reference_timestamps,
update_reference_updated_at, update_reference_updated_at,
upsert_reference, upsert_reference,
@ -54,6 +59,7 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta, bulk_insert_tags_and_meta,
ensure_tags_exist, ensure_tags_exist,
get_reference_tags, get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage, list_tags_with_usage,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
remove_tags_from_reference, remove_tags_from_reference,
@ -97,20 +103,26 @@ __all__ = [
"get_unenriched_references", "get_unenriched_references",
"get_unreferenced_unhashed_asset_ids", "get_unreferenced_unhashed_asset_ids",
"insert_reference", "insert_reference",
"list_all_file_paths_by_asset_id",
"list_references_by_asset_id", "list_references_by_asset_id",
"list_references_page", "list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage", "list_tags_with_usage",
"mark_references_missing_outside_prefixes", "mark_references_missing_outside_prefixes",
"reassign_asset_references", "reassign_asset_references",
"rebuild_metadata_projection",
"reference_exists",
"reference_exists_for_asset_id", "reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id", "remove_missing_tag_for_asset_id",
"remove_tags_from_reference", "remove_tags_from_reference",
"restore_references_by_paths", "restore_references_by_paths",
"set_reference_metadata", "set_reference_metadata",
"set_reference_preview", "set_reference_preview",
"set_reference_system_metadata",
"soft_delete_reference_by_id", "soft_delete_reference_by_id",
"set_reference_tags", "set_reference_tags",
"update_asset_hash_and_mime", "update_asset_hash_and_mime",
"update_is_missing_by_asset_id",
"update_reference_access_time", "update_reference_access_time",
"update_reference_name", "update_reference_name",
"update_reference_timestamps", "update_reference_timestamps",

View File

@ -69,7 +69,7 @@ def upsert_asset(
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes) asset.size_bytes = int(size_bytes)
changed = True changed = True
if mime_type and asset.mime_type != mime_type: if mime_type and not asset.mime_type:
asset.mime_type = mime_type asset.mime_type = mime_type
changed = True changed = True
if changed: if changed:
@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
return False return False
if asset_hash is not None: if asset_hash is not None:
asset.hash = asset_hash asset.hash = asset_hash
if mime_type is not None: if mime_type is not None and not asset.mime_type:
asset.mime_type = mime_type asset.mime_type = mime_type
return True return True

View File

@ -10,7 +10,7 @@ from decimal import Decimal
from typing import NamedTuple, Sequence from typing import NamedTuple, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import delete, exists, select from sqlalchemy import delete, select
from sqlalchemy.dialects import sqlite from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, noload from sqlalchemy.orm import Session, noload
@ -24,12 +24,14 @@ from app.assets.database.models import (
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
MAX_BIND_PARAMS, MAX_BIND_PARAMS,
apply_metadata_filter,
apply_tag_filters,
build_prefix_like_conditions, build_prefix_like_conditions,
build_visible_owner_clause, build_visible_owner_clause,
calculate_rows_per_statement, calculate_rows_per_statement,
iter_chunks, iter_chunks,
) )
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags from app.assets.helpers import escape_sql_like_string, get_utc_now
def _check_is_scalar(v): def _check_is_scalar(v):
@ -79,83 +81,6 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]:
return [{"key": key, "ordinal": 0, "val_json": value}] return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetReferenceMeta.val_json.is_(None),
AssetReferenceMeta.val_str.is_(None),
AssetReferenceMeta.val_num.is_(None),
AssetReferenceMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def get_reference_by_id( def get_reference_by_id(
@ -212,6 +137,21 @@ def reference_exists_for_asset_id(
return session.execute(q).first() is not None return session.execute(q).first() is not None
def reference_exists(
session: Session,
reference_id: str,
) -> bool:
"""Return True if a reference with the given ID exists (not soft-deleted)."""
q = (
select(sa.literal(True))
.select_from(AssetReference)
.where(AssetReference.id == reference_id)
.where(AssetReference.deleted_at.is_(None))
.limit(1)
)
return session.execute(q).first() is not None
def insert_reference( def insert_reference(
session: Session, session: Session,
asset_id: str, asset_id: str,
@ -336,8 +276,8 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains) escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower() sort = (sort or "created_at").lower()
order = (order or "desc").lower() order = (order or "desc").lower()
@ -366,8 +306,8 @@ def list_references_page(
count_stmt = count_stmt.where( count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc) AssetReference.name.ilike(f"%{escaped}%", escape=esc)
) )
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0) total = int(session.execute(count_stmt).scalar_one() or 0)
refs = session.execute(base).unique().scalars().all() refs = session.execute(base).unique().scalars().all()
@ -379,7 +319,7 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name) select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name) .join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list)) .where(AssetReferenceTag.asset_reference_id.in_(id_list))
.order_by(AssetReferenceTag.added_at) .order_by(AssetReferenceTag.tag_name.asc())
) )
for ref_id, tag_name in rows.all(): for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name) tag_map[ref_id].append(tag_name)
@ -492,6 +432,42 @@ def update_reference_updated_at(
) )
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
override system keys of the same name.
"""
session.execute(
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == ref.id
)
)
session.flush()
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
if not merged:
return
rows: list[AssetReferenceMeta] = []
for k, v in merged.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=ref.id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def set_reference_metadata( def set_reference_metadata(
session: Session, session: Session,
reference_id: str, reference_id: str,
@ -505,33 +481,24 @@ def set_reference_metadata(
ref.updated_at = get_utc_now() ref.updated_at = get_utc_now()
session.flush() session.flush()
session.execute( rebuild_metadata_projection(session, ref)
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == reference_id
) def set_reference_system_metadata(
) session: Session,
reference_id: str,
system_metadata: dict | None = None,
) -> None:
"""Set system_metadata on a reference and rebuild the merged projection."""
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
ref.system_metadata = system_metadata or {}
ref.updated_at = get_utc_now()
session.flush() session.flush()
if not user_metadata: rebuild_metadata_projection(session, ref)
return
rows: list[AssetReferenceMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=reference_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_reference_by_id( def delete_reference_by_id(
@ -571,19 +538,19 @@ def soft_delete_reference_by_id(
def set_reference_preview( def set_reference_preview(
session: Session, session: Session,
reference_id: str, reference_id: str,
preview_asset_id: str | None = None, preview_reference_id: str | None = None,
) -> None: ) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
ref = session.get(AssetReference, reference_id) ref = session.get(AssetReference, reference_id)
if not ref: if not ref:
raise ValueError(f"AssetReference {reference_id} not found") raise ValueError(f"AssetReference {reference_id} not found")
if preview_asset_id is None: if preview_reference_id is None:
ref.preview_id = None ref.preview_id = None
else: else:
if not session.get(Asset, preview_asset_id): if not session.get(AssetReference, preview_reference_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found") raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
ref.preview_id = preview_asset_id ref.preview_id = preview_reference_id
ref.updated_at = get_utc_now() ref.updated_at = get_utc_now()
session.flush() session.flush()
@ -609,6 +576,8 @@ def list_references_by_asset_id(
session.execute( session.execute(
select(AssetReference) select(AssetReference)
.where(AssetReference.asset_id == asset_id) .where(AssetReference.asset_id == asset_id)
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
.order_by(AssetReference.id.asc()) .order_by(AssetReference.id.asc())
) )
.scalars() .scalars()
@ -616,6 +585,25 @@ def list_references_by_asset_id(
) )
def list_all_file_paths_by_asset_id(
session: Session,
asset_id: str,
) -> list[str]:
"""Return every file_path for an asset, including soft-deleted/missing refs.
Used for orphan cleanup where all on-disk files must be removed.
"""
return list(
session.execute(
select(AssetReference.file_path)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.file_path.isnot(None))
)
.scalars()
.all()
)
def upsert_reference( def upsert_reference(
session: Session, session: Session,
asset_id: str, asset_id: str,
@ -855,6 +843,22 @@ def bulk_update_is_missing(
return total return total
def update_is_missing_by_asset_id(
session: Session, asset_id: str, value: bool
) -> int:
"""Set is_missing flag for ALL references belonging to an asset.
Returns: Number of rows updated
"""
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.deleted_at.is_(None))
.values(is_missing=value)
)
return result.rowcount
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
"""Delete references by their IDs. """Delete references by their IDs.

View File

@ -1,12 +1,14 @@
"""Shared utilities for database query modules.""" """Shared utilities for database query modules."""
import os import os
from typing import Iterable from decimal import Decimal
from typing import Iterable, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800 MAX_BIND_PARAMS = 800
@ -52,3 +54,82 @@ def build_prefix_like_conditions(
escaped, esc = escape_sql_like_string(base) escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds return conds
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetReferenceMeta.val_json.is_(None),
AssetReferenceMeta.val_str.is_(None),
AssetReferenceMeta.val_num.is_(None),
AssetReferenceMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.database.models import ( from app.assets.database.models import (
Asset,
AssetReference, AssetReference,
AssetReferenceMeta, AssetReferenceMeta,
AssetReferenceTag, AssetReferenceTag,
Tag, Tag,
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause, build_visible_owner_clause,
iter_row_chunks, iter_row_chunks,
) )
@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
tag_name tag_name
for (tag_name,) in ( for (tag_name,) in (
session.execute( session.execute(
select(AssetReferenceTag.tag_name).where( select(AssetReferenceTag.tag_name)
AssetReferenceTag.asset_reference_id == reference_id .where(AssetReferenceTag.asset_reference_id == reference_id)
) .order_by(AssetReferenceTag.tag_name.asc())
) )
).all() ).all()
] ]
@ -117,7 +120,7 @@ def set_reference_tags(
) )
session.flush() session.flush()
return SetTagsResult(added=to_add, removed=to_remove, total=desired) return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
def add_tags_to_reference( def add_tags_to_reference(
@ -272,6 +275,12 @@ def list_tags_with_usage(
.select_from(AssetReferenceTag) .select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id)) .where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None)) .where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name) .group_by(AssetReferenceTag.tag_name)
.subquery() .subquery()
@ -308,6 +317,12 @@ def list_tags_with_usage(
select(AssetReferenceTag.tag_name) select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id)) .where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None)) .where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name) .group_by(AssetReferenceTag.tag_name)
) )
@ -320,6 +335,53 @@ def list_tags_with_usage(
return rows_norm, int(total or 0) return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta( def bulk_insert_tags_and_meta(
session: Session, session: Session,
tag_rows: list[dict], tag_rows: list[dict],

View File

@ -18,7 +18,7 @@ from app.assets.database.queries import (
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
reassign_asset_references, reassign_asset_references,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_system_metadata,
update_asset_hash_and_mime, update_asset_hash_and_mime,
) )
from app.assets.services.bulk_ingest import ( from app.assets.services.bulk_ingest import (
@ -490,8 +490,8 @@ def enrich_asset(
logging.warning("Failed to hash %s: %s", file_path, e) logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata: if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata() system_metadata = metadata.to_user_metadata()
set_reference_metadata(session, reference_id, user_metadata) set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash: if full_hash:
existing = get_asset_by_hash(session, full_hash) existing = get_asset_by_hash(session, full_hash)

View File

@ -16,10 +16,12 @@ from app.assets.database.queries import (
get_reference_by_id, get_reference_by_id,
get_reference_with_owner_check, get_reference_with_owner_check,
list_references_page, list_references_page,
list_all_file_paths_by_asset_id,
list_references_by_asset_id, list_references_by_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_reference_updated_at, update_reference_updated_at,
@ -67,6 +69,8 @@ def update_asset_metadata(
user_metadata: UserMetadata = None, user_metadata: UserMetadata = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id) ref = get_reference_with_owner_check(session, reference_id, owner_id)
@ -103,6 +107,21 @@ def update_asset_metadata(
) )
touched = True touched = True
if mime_type is not None:
updated = update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
if updated:
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_id,
)
touched = True
if touched and user_metadata is None: if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id) update_reference_updated_at(session, reference_id=reference_id)
@ -159,11 +178,9 @@ def delete_asset_reference(
session.commit() session.commit()
return True return True
# Orphaned asset - delete it and its files # Orphaned asset - gather ALL file paths (including
refs = list_references_by_asset_id(session, asset_id=asset_id) # soft-deleted / missing refs) so their on-disk files get cleaned up.
file_paths = [ file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
]
# Also include the just-deleted file path # Also include the just-deleted file path
if file_path: if file_path:
file_paths.append(file_path) file_paths.append(file_path)
@ -185,7 +202,7 @@ def delete_asset_reference(
def set_asset_preview( def set_asset_preview(
reference_id: str, reference_id: str,
preview_asset_id: str | None = None, preview_reference_id: str | None = None,
owner_id: str = "", owner_id: str = "",
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
@ -194,7 +211,7 @@ def set_asset_preview(
set_reference_preview( set_reference_preview(
session, session,
reference_id=reference_id, reference_id=reference_id,
preview_asset_id=preview_asset_id, preview_reference_id=preview_reference_id,
) )
result = fetch_reference_asset_and_tags( result = fetch_reference_asset_and_tags(
@ -263,6 +280,47 @@ def list_assets_page(
return ListAssetsResult(items=items, total=total) return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(
asset_hash: str,
owner_id: str = "",
) -> DownloadResolutionResult | None:
"""Resolve a blake3 hash to an on-disk file path.
Only references visible to *owner_id* are considered (owner-less
references are always visible).
Returns a DownloadResolutionResult with abs_path, content_type, and
download_name, or None if no asset or live path is found.
"""
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash)
if not asset:
return None
refs = list_references_by_asset_id(session, asset_id=asset.id)
visible = [
r for r in refs
if r.owner_id == "" or r.owner_id == owner_id
]
abs_path = select_best_live_path(visible)
if not abs_path:
return None
display_name = os.path.basename(abs_path)
for ref in visible:
if ref.file_path == abs_path and ref.name:
display_name = ref.name
break
ctype = (
asset.mime_type
or mimetypes.guess_type(display_name)[0]
or "application/octet-stream"
)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=display_name,
)
def resolve_asset_for_download( def resolve_asset_for_download(
reference_id: str, reference_id: str,
owner_id: str = "", owner_id: str = "",

View File

@ -11,13 +11,14 @@ from app.assets.database.queries import (
add_tags_to_reference, add_tags_to_reference,
fetch_reference_and_asset, fetch_reference_and_asset,
get_asset_by_hash, get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path, get_reference_by_file_path,
get_reference_tags, get_reference_tags,
get_or_create_reference, get_or_create_reference,
reference_exists,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
upsert_asset, upsert_asset,
upsert_reference, upsert_reference,
validate_tags_exist, validate_tags_exist,
@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_relative_filename, compute_relative_filename,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags, resolve_destination_from_tags,
validate_path_within_base, validate_path_within_base,
) )
@ -65,7 +67,7 @@ def _ingest_file_from_path(
with create_session() as session: with create_session() as session:
if preview_id: if preview_id:
if preview_id not in get_existing_asset_ids(session, [preview_id]): if not reference_exists(session, preview_id):
preview_id = None preview_id = None
asset, asset_created, asset_updated = upsert_asset( asset, asset_created, asset_updated = upsert_asset(
@ -135,6 +137,8 @@ def _register_existing_asset(
tags: list[str] | None = None, tags: list[str] | None = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> RegisterAssetResult: ) -> RegisterAssetResult:
user_metadata = user_metadata or {} user_metadata = user_metadata or {}
@ -143,14 +147,25 @@ def _register_existing_asset(
if not asset: if not asset:
raise ValueError(f"No asset with hash {asset_hash}") raise ValueError(f"No asset with hash {asset_hash}")
if mime_type and not asset.mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
ref, ref_created = get_or_create_reference( ref, ref_created = get_or_create_reference(
session, session,
asset_id=asset.id, asset_id=asset.id,
owner_id=owner_id, owner_id=owner_id,
name=name, name=name,
preview_id=preview_id,
) )
if not ref_created: if not ref_created:
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
tag_names = get_reference_tags(session, reference_id=ref.id) tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult( result = RegisterAssetResult(
ref=extract_reference_data(ref), ref=extract_reference_data(ref),
@ -242,6 +257,8 @@ def upload_from_temp_path(
client_filename: str | None = None, client_filename: str | None = None,
owner_id: str = "", owner_id: str = "",
expected_hash: str | None = None, expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult: ) -> UploadResult:
try: try:
digest, _ = hashing.compute_blake3_hash(temp_path) digest, _ = hashing.compute_blake3_hash(temp_path)
@ -270,6 +287,8 @@ def upload_from_temp_path(
tags=tags or [], tags=tags or [],
tag_origin="manual", tag_origin="manual",
owner_id=owner_id, owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
) )
return UploadResult( return UploadResult(
ref=result.ref, ref=result.ref,
@ -291,7 +310,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir) validate_path_within_base(dest_abs, base_dir)
content_type = ( content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0] or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream" or "application/octet-stream"
@ -315,7 +334,7 @@ def upload_from_temp_path(
mime_type=content_type, mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest), info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id, owner_id=owner_id,
preview_id=None, preview_id=preview_id,
user_metadata=user_metadata or {}, user_metadata=user_metadata or {},
tags=tags, tags=tags,
tag_origin="manual", tag_origin="manual",
@ -342,30 +361,99 @@ def upload_from_temp_path(
) )
def register_file_in_place(
abs_path: str,
name: str,
tags: list[str],
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash file: {e}")
asset_hash = "blake3:" + digest
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
content_type = mime_type or (
mimetypes.guess_type(abs_path, strict=False)[0]
or "application/octet-stream"
)
ingest_result = _ingest_file_from_path(
abs_path=abs_path,
asset_hash=asset_hash,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name, fallback=digest),
owner_id=owner_id,
tags=merged_tags,
tag_origin="upload",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash( def create_from_hash(
hash_str: str, hash_str: str,
name: str, name: str,
tags: list[str] | None = None, tags: list[str] | None = None,
user_metadata: dict | None = None, user_metadata: dict | None = None,
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult | None: ) -> UploadResult | None:
canonical = hash_str.strip().lower() canonical = hash_str.strip().lower()
with create_session() as session: try:
asset = get_asset_by_hash(session, asset_hash=canonical) result = _register_existing_asset(
if not asset: asset_hash=canonical,
return None name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
result = _register_existing_asset( ),
asset_hash=canonical, user_metadata=user_metadata or {},
name=_sanitize_filename( tags=tags or [],
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical tag_origin="manual",
), owner_id=owner_id,
user_metadata=user_metadata or {}, mime_type=mime_type,
tags=tags or [], preview_id=preview_id,
tag_origin="manual", )
owner_id=owner_id, except ValueError:
) logging.warning("create_from_hash: no asset found for hash %s", canonical)
return None
return UploadResult( return UploadResult(
ref=result.ref, ref=result.ref,

View File

@ -25,7 +25,9 @@ class ReferenceData:
preview_id: str | None preview_id: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
last_access_time: datetime | None system_metadata: dict[str, Any] | None = None
prompt_id: str | None = None
last_access_time: datetime | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path, file_path=ref.file_path,
user_metadata=ref.user_metadata, user_metadata=ref.user_metadata,
preview_id=ref.preview_id, preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
prompt_id=ref.prompt_id,
created_at=ref.created_at, created_at=ref.created_at,
updated_at=ref.updated_at, updated_at=ref.updated_at,
last_access_time=ref.last_access_time, last_access_time=ref.last_access_time,

View File

@ -1,3 +1,5 @@
from typing import Sequence
from app.assets.database.queries import ( from app.assets.database.queries import (
AddTagsResult, AddTagsResult,
RemoveTagsResult, RemoveTagsResult,
@ -6,6 +8,7 @@ from app.assets.database.queries import (
list_tags_with_usage, list_tags_with_usage,
remove_tags_from_reference, remove_tags_from_reference,
) )
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage from app.assets.services.schemas import TagUsage
from app.database.db import create_session from app.database.db import create_session
@ -73,3 +76,23 @@ def list_tags(
) )
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@ -1,9 +1,18 @@
from typing import Any from typing import Any
from datetime import datetime from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass metadata = MetaData(naming_convention=NAMING_CONVENTION)
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys() fields = obj.__table__.columns.keys()

View File

@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -419,7 +421,24 @@ class PromptServer():
with open(filepath, "wb") as f: with open(filepath, "wb") as f:
f.write(image.file.read()) f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,
"asset_hash": result.asset.hash if result.asset else None,
"size": result.asset.size_bytes if result.asset else None,
"mime_type": result.asset.mime_type if result.asset else None,
"tags": result.tags,
}
except Exception:
logging.warning("Failed to register uploaded image as asset", exc_info=True)
return web.json_response(resp)
else: else:
return web.Response(status=400) return web.Response(status=400)
@ -479,30 +498,43 @@ class PromptServer():
async def view_image(request): async def view_image(request):
if "filename" in request.rel_url.query: if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"] filename = request.rel_url.query["filename"]
filename, output_dir = folder_paths.annotated_filepath(filename)
if not filename: # The frontend's LoadImage combo widget uses asset_hash values
return web.Response(status=400) # (e.g. "blake3:...") as widget values. When litegraph renders the
# node preview, it constructs /view?filename=<asset_hash>, so this
# endpoint must resolve blake3 hashes to their on-disk file paths.
if filename.startswith("blake3:"):
owner_id = self.user_manager.get_request_user_id(request)
result = resolve_hash_to_path(filename, owner_id=owner_id)
if result is None:
return web.Response(status=404)
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
else:
resolved_content_type = None
filename, output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path if not filename:
if filename[0] == '/' or '..' in filename: return web.Response(status=400)
return web.Response(status=400)
if output_dir is None: # validation for security: prevent accessing arbitrary path
type = request.rel_url.query.get("type", "output") if filename[0] == '/' or '..' in filename:
output_dir = folder_paths.get_directory_by_type(type) return web.Response(status=400)
if output_dir is None: if output_dir is None:
return web.Response(status=400) type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
if "subfolder" in request.rel_url.query: if output_dir is None:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) return web.Response(status=400)
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename) if "subfolder" in request.rel_url.query:
file = os.path.join(output_dir, filename) full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
if os.path.isfile(file): if os.path.isfile(file):
if 'preview' in request.rel_url.query: if 'preview' in request.rel_url.query:
@ -562,8 +594,13 @@ class PromptServer():
return web.Response(body=alpha_buffer.read(), content_type='image/png', return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"filename=\"{filename}\""})
else: else:
# Get content type from mimetype, defaulting to 'application/octet-stream' # Use the content type from asset resolution if available,
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' # otherwise guess from the filename.
content_type = (
resolved_content_type
or mimetypes.guess_type(filename)[0]
or 'application/octet-stream'
)
# For security, force certain mimetypes to download instead of display # For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:

View File

@ -0,0 +1,57 @@
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
This catches problems like unnamed FK constraints that prevent batch-mode
drop_constraint from working on real SQLite files (see MB-2).
Migrations 0001 and 0002 are already shipped, so we only exercise
upgrade/downgrade for 0003+.
"""
import os
import pytest
from alembic import command
from alembic.config import Config
# Oldest shipped revision — we upgrade to here as a baseline and never
# downgrade past it.
_BASELINE = "0002_merge_to_asset_references"
def _make_config(db_path: str) -> Config:
root = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
return cfg
@pytest.fixture
def migration_db(tmp_path):
"""Yield an alembic Config pre-upgraded to the baseline revision."""
db_path = str(tmp_path / "test_migration.db")
cfg = _make_config(db_path)
command.upgrade(cfg, _BASELINE)
yield cfg
def test_upgrade_to_head(migration_db):
"""Upgrade from baseline to head must succeed on a file-backed DB."""
command.upgrade(migration_db, "head")
def test_downgrade_to_baseline(migration_db):
"""Upgrade to head then downgrade back to baseline."""
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
def test_upgrade_downgrade_cycle(migration_db):
"""Full cycle: upgrade → downgrade → upgrade again."""
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
command.upgrade(migration_db, "head")

View File

@ -10,6 +10,7 @@ from app.assets.database.queries import (
get_asset_by_hash, get_asset_by_hash,
upsert_asset, upsert_asset,
bulk_insert_assets, bulk_insert_assets,
update_asset_hash_and_mime,
) )
@ -142,3 +143,45 @@ class TestBulkInsertAssets:
session.commit() session.commit()
assert session.query(Asset).count() == 200 assert session.query(Asset).count() == 200
class TestMimeTypeImmutability:
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
@pytest.mark.parametrize(
"initial_mime,second_mime,expected_mime",
[
("image/png", "image/jpeg", "image/png"),
(None, "image/png", "image/png"),
],
ids=["preserves_existing", "fills_null"],
)
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
h = f"blake3:upsert_{initial_mime}_{second_mime}"
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
session.commit()
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
assert created is False
assert asset.mime_type == expected_mime
@pytest.mark.parametrize(
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
[
(None, "image/png", None, "image/png", "blake3:upd0"),
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
],
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
)
def test_update_asset_hash_and_mime_immutability(
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
):
h = expected_hash.removesuffix("_new")
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
session.add(asset)
session.flush()
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
assert asset.mime_type == expected_mime
assert asset.hash == expected_hash

View File

@ -242,22 +242,24 @@ class TestSetReferencePreview:
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit() session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
assert ref.preview_id == preview_asset.id assert ref.preview_id == preview_ref.id
def test_clears_preview(self, session: Session): def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
session.commit() session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
@ -265,15 +267,15 @@ class TestSetReferencePreview:
def test_raises_for_nonexistent_reference(self, session: Session): def test_raises_for_nonexistent_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
def test_raises_for_nonexistent_preview(self, session: Session): def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
session.commit() session.commit()
with pytest.raises(ValueError, match="Preview Asset"): with pytest.raises(ValueError, match="Preview AssetReference"):
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
class TestInsertReference: class TestInsertReference:
@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit() session.commit()
update_reference_timestamps(session, ref, preview_id=preview_asset.id) update_reference_timestamps(session, ref, preview_id=preview_ref.id)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
assert ref.preview_id == preview_asset.id assert ref.preview_id == preview_ref.id
class TestSetReferenceMetadata: class TestSetReferenceMetadata:

View File

@ -20,6 +20,7 @@ def _make_reference(
asset: Asset, asset: Asset,
name: str, name: str,
metadata: dict | None = None, metadata: dict | None = None,
system_metadata: dict | None = None,
) -> AssetReference: ) -> AssetReference:
now = get_utc_now() now = get_utc_now()
ref = AssetReference( ref = AssetReference(
@ -27,6 +28,7 @@ def _make_reference(
name=name, name=name,
asset_id=asset.id, asset_id=asset.id,
user_metadata=metadata, user_metadata=metadata,
system_metadata=system_metadata,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
last_access_time=now, last_access_time=now,
@ -34,8 +36,10 @@ def _make_reference(
session.add(ref) session.add(ref)
session.flush() session.flush()
if metadata: # Build merged projection: {**system_metadata, **user_metadata}
for key, val in metadata.items(): merged = {**(system_metadata or {}), **(metadata or {})}
if merged:
for key, val in merged.items():
for row in convert_metadata_to_rows(key, val): for row in convert_metadata_to_rows(key, val):
meta_row = AssetReferenceMeta( meta_row = AssetReferenceMeta(
asset_reference_id=ref.id, asset_reference_id=ref.id,
@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
refs, _, total = list_references_page(session, metadata_filter={}) refs, _, total = list_references_page(session, metadata_filter={})
assert total == 2 assert total == 2
class TestSystemMetadataProjection:
"""Tests for system_metadata merging into the filter projection."""
def test_system_metadata_keys_are_filterable(self, session: Session):
"""system_metadata keys should appear in the merged projection."""
asset = _make_asset(session, "hash1")
_make_reference(
session, asset, "with_sys",
system_metadata={"source": "scanner"},
)
_make_reference(session, asset, "without_sys")
session.commit()
refs, _, total = list_references_page(
session, metadata_filter={"source": "scanner"}
)
assert total == 1
assert refs[0].name == "with_sys"
def test_user_metadata_overrides_system_metadata(self, session: Session):
"""user_metadata should win when both have the same key."""
asset = _make_asset(session, "hash1")
_make_reference(
session, asset, "overridden",
metadata={"origin": "user_upload"},
system_metadata={"origin": "auto_scan"},
)
session.commit()
# Should match the user value, not the system value
refs, _, total = list_references_page(
session, metadata_filter={"origin": "user_upload"}
)
assert total == 1
assert refs[0].name == "overridden"
# Should NOT match the system value (it was overridden)
refs, _, total = list_references_page(
session, metadata_filter={"origin": "auto_scan"}
)
assert total == 0

View File

@ -11,6 +11,7 @@ from app.assets.services import (
delete_asset_reference, delete_asset_reference,
set_asset_preview, set_asset_preview,
) )
from app.assets.services.asset_management import resolve_hash_to_path
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
@ -219,31 +220,33 @@ class TestSetAssetPreview:
asset = _make_asset(session, hash_val="blake3:main") asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview") preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref_id = ref.id ref_id = ref.id
preview_id = preview_asset.id preview_ref_id = preview_ref.id
session.commit() session.commit()
set_asset_preview( set_asset_preview(
reference_id=ref_id, reference_id=ref_id,
preview_asset_id=preview_id, preview_reference_id=preview_ref_id,
) )
# Verify by re-fetching from DB # Verify by re-fetching from DB
session.expire_all() session.expire_all()
updated_ref = session.get(AssetReference, ref_id) updated_ref = session.get(AssetReference, ref_id)
assert updated_ref.preview_id == preview_id assert updated_ref.preview_id == preview_ref_id
def test_clears_preview(self, mock_create_session, session: Session): def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session) asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview") preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
ref_id = ref.id ref_id = ref.id
session.commit() session.commit()
set_asset_preview( set_asset_preview(
reference_id=ref_id, reference_id=ref_id,
preview_asset_id=None, preview_reference_id=None,
) )
# Verify by re-fetching from DB # Verify by re-fetching from DB
@ -263,6 +266,45 @@ class TestSetAssetPreview:
with pytest.raises(PermissionError, match="not owner"): with pytest.raises(PermissionError, match="not owner"):
set_asset_preview( set_asset_preview(
reference_id=ref.id, reference_id=ref.id,
preview_asset_id=None, preview_reference_id=None,
owner_id="user2", owner_id="user2",
) )
class TestResolveHashToPath:
def test_returns_none_for_unknown_hash(self, mock_create_session):
result = resolve_hash_to_path("blake3:" + "a" * 64)
assert result is None
@pytest.mark.parametrize(
"ref_owner, query_owner, expect_found",
[
("user1", "user1", True),
("user1", "user2", False),
("", "anyone", True),
("", "", True),
],
ids=[
"owner_sees_own_ref",
"other_owner_blocked",
"ownerless_visible_to_anyone",
"ownerless_visible_to_empty",
],
)
def test_owner_visibility(
self, ref_owner, query_owner, expect_found,
mock_create_session, session: Session, temp_dir,
):
f = temp_dir / "file.bin"
f.write_bytes(b"data")
asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
ref.file_path = str(f)
session.commit()
result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
if expect_found:
assert result is not None
assert result.abs_path == str(f)
else:
assert result is None

View File

@ -113,11 +113,19 @@ class TestIngestFileFromPath:
file_path = temp_dir / "with_preview.bin" file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data") file_path.write_bytes(b"data")
# Create a preview asset first # Create a preview asset and reference
preview_asset = Asset(hash="blake3:preview", size_bytes=100) preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset) session.add(preview_asset)
session.flush()
from app.assets.helpers import get_utc_now
now = get_utc_now()
preview_ref = AssetReference(
asset_id=preview_asset.id, name="preview.png", owner_id="",
created_at=now, updated_at=now, last_access_time=now,
)
session.add(preview_ref)
session.commit() session.commit()
preview_id = preview_asset.id preview_id = preview_ref.id
result = _ingest_file_from_path( result = _ingest_file_from_path(
abs_path=str(file_path), abs_path=str(file_path),

View File

@ -0,0 +1,123 @@
"""Tests for list_tag_histogram service function."""
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
from app.assets.helpers import get_utc_now
from app.assets.services.tagging import list_tag_histogram
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestListTagHistogram:
def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="r1")
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="r2")
add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
session.commit()
result = list_tag_histogram()
assert result["alpha"] == 2
assert result["beta"] == 1
def test_empty_when_no_assets(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["unused"])
session.commit()
result = list_tag_histogram()
assert result == {}
def test_include_tags_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["models", "loras", "input"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="r1")
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="r2")
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
session.commit()
result = list_tag_histogram(include_tags=["models"])
# Only r1 has "models", so only its tags appear
assert "models" in result
assert "loras" in result
assert "input" not in result
def test_exclude_tags_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["models", "loras", "input"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="r1")
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="r2")
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
session.commit()
result = list_tag_histogram(exclude_tags=["models"])
# r1 excluded, only r2's tags remain
assert "input" in result
assert "loras" not in result
def test_name_contains_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="my_model.safetensors")
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="picture.png")
add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
session.commit()
result = list_tag_histogram(name_contains="model")
assert "alpha" in result
assert "beta" not in result
def test_limit_caps_results(self, mock_create_session, session: Session):
tags = [f"tag{i}" for i in range(10)]
ensure_tags_exist(session, tags)
a = _make_asset(session, "blake3:aaa")
r = _make_reference(session, a, name="r1")
add_tags_to_reference(session, reference_id=r.id, tags=tags)
session.commit()
result = list_tag_histogram(limit=3)
assert len(result) == 3

View File

@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize("root", ["input", "output"]) @pytest.mark.parametrize("root", ["input", "output"])
def test_duplicate_upload_same_display_name_does_not_clobber( def test_duplicate_upload_same_display_name_does_not_clobber(
root: str, root: str,