remove timezone; download asset, delete asset endpoints

This commit is contained in:
bigcat88 2025-08-24 12:01:59 +03:00
parent 8d46bec951
commit 0755e5320a
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
6 changed files with 174 additions and 34 deletions

View File

@ -24,8 +24,8 @@ def upgrade() -> None:
sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"), sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
sa.Column("storage_locator", sa.Text(), nullable=False), sa.Column("storage_locator", sa.Text(), nullable=False),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"), sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
) )
@ -41,9 +41,9 @@ def upgrade() -> None:
sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True), sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sqlite_autoincrement=True, sqlite_autoincrement=True,
) )
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
@ -68,7 +68,7 @@ def upgrade() -> None:
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_by", sa.String(length=128), nullable=True),
sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")), sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"), sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
) )
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"]) op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])

View File

@ -1,3 +1,4 @@
import urllib.parse
from typing import Optional from typing import Optional
from aiohttp import web from aiohttp import web
@ -32,6 +33,39 @@ async def list_assets(request: web.Request) -> web.Response:
return web.json_response(payload.model_dump(mode="json")) return web.json_response(payload.model_dump(mode="json"))
@ROUTES.get("/api/assets/{id}/content")
async def download_asset_content(request: web.Request) -> web.Response:
asset_info_id_raw = request.match_info.get("id")
try:
asset_info_id = int(asset_info_id_raw)
except Exception:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
disposition = request.query.get("disposition", "attachment").lower().strip()
if disposition not in {"inline", "attachment"}:
disposition = "attachment"
try:
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
asset_info_id=asset_info_id
)
except ValueError as ve:
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
except NotImplementedError as nie:
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
except FileNotFoundError:
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
quoted = filename.replace('"', "'")
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
resp = web.FileResponse(abs_path)
resp.content_type = content_type
resp.headers["Content-Disposition"] = cd
return resp
@ROUTES.put("/api/assets/{id}") @ROUTES.put("/api/assets/{id}")
async def update_asset(request: web.Request) -> web.Response: async def update_asset(request: web.Request) -> web.Response:
asset_info_id_raw = request.match_info.get("id") asset_info_id_raw = request.match_info.get("id")
@ -61,6 +95,24 @@ async def update_asset(request: web.Request) -> web.Response:
return web.json_response(result.model_dump(mode="json"), status=200) return web.json_response(result.model_dump(mode="json"), status=200)
@ROUTES.delete("/api/assets/{id}")
async def delete_asset(request: web.Request) -> web.Response:
asset_info_id_raw = request.match_info.get("id")
try:
asset_info_id = int(asset_info_id_raw)
except Exception:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
try:
deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id)
except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.")
if not deleted:
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
return web.Response(status=204)
@ROUTES.get("/api/tags") @ROUTES.get("/api/tags")
async def get_tags(request: web.Request) -> web.Response: async def get_tags(request: web.Request) -> web.Response:
query_map = dict(request.rel_url.query) query_map = dict(request.rel_url.query)

View File

@ -1,5 +1,5 @@
import mimetypes
import os import os
from datetime import datetime, timezone
from typing import Optional, Sequence from typing import Optional, Sequence
from comfy.cli_args import args from comfy.cli_args import args
@ -17,6 +17,9 @@ from .database.services import (
list_tags_with_usage, list_tags_with_usage,
add_tags_to_asset_info, add_tags_to_asset_info,
remove_tags_from_asset_info, remove_tags_from_asset_info,
fetch_asset_info_and_asset,
touch_asset_info_by_id,
delete_asset_info_by_id,
) )
from .api import schemas_out from .api import schemas_out
@ -43,7 +46,7 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No
async with await create_session() as session: async with await create_session() as session:
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns): if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc)) await touch_asset_infos_by_fs_path(session, abs_path=abs_path)
await session.commit() await session.commit()
return return
@ -117,6 +120,40 @@ async def list_assets(
) )
async def resolve_asset_content_for_download(
*, asset_info_id: int
) -> tuple[str, str, str]:
"""
Returns (abs_path, content_type, download_name) for the given AssetInfo id.
Also touches last_access_time (only_if_newer).
Raises:
ValueError if AssetInfo not found
NotImplementedError for unsupported backend
FileNotFoundError if underlying file does not exist (fs backend)
"""
async with await create_session() as session:
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
if asset.storage_backend != "fs":
# Future: support http/s3/gcs/...
raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet")
abs_path = os.path.abspath(asset.storage_locator)
if not os.path.exists(abs_path):
raise FileNotFoundError(abs_path)
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
async def update_asset( async def update_asset(
*, *,
asset_info_id: int, asset_info_id: int,
@ -148,6 +185,12 @@ async def update_asset(
) )
async def delete_asset_reference(*, asset_info_id: int) -> bool:
async with await create_session() as session:
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id)
await session.commit()
return r
async def list_tags( async def list_tags(
*, *,

View File

@ -14,9 +14,10 @@ from sqlalchemy import (
Numeric, Numeric,
Boolean, Boolean,
) )
from sqlalchemy.sql import func
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
from .timeutil import utcnow
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass pass
@ -46,10 +47,10 @@ class Asset(Base):
storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs") storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
storage_locator: Mapped[str] = mapped_column(Text, nullable=False) storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() DateTime(timezone=False), nullable=False, default=utcnow
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() DateTime(timezone=False), nullable=False, default=utcnow
) )
infos: Mapped[list["AssetInfo"]] = relationship( infos: Mapped[list["AssetInfo"]] = relationship(
@ -125,13 +126,13 @@ class AssetInfo(Base):
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON) user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() DateTime(timezone=False), nullable=False, default=utcnow
) )
updated_at: Mapped[datetime] = mapped_column( updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() DateTime(timezone=False), nullable=False, default=utcnow
) )
last_access_time: Mapped[datetime] = mapped_column( last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp() DateTime(timezone=False), nullable=False, default=utcnow
) )
# Relationships # Relationships
@ -221,7 +222,9 @@ class AssetInfoTag(Base):
) )
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_by: Mapped[str | None] = mapped_column(String(128)) added_by: Mapped[str | None] = mapped_column(String(128))
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False) added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
)
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links") asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links") tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")

View File

@ -1,7 +1,7 @@
import os import os
import logging import logging
from collections import defaultdict from collections import defaultdict
from datetime import datetime, timezone from datetime import datetime
from decimal import Decimal from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable from typing import Any, Sequence, Optional, Iterable
@ -12,6 +12,7 @@ from sqlalchemy.orm import contains_eager
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
from .timeutil import utcnow
async def check_fs_asset_exists_quick( async def check_fs_asset_exists_quick(
@ -93,7 +94,7 @@ async def ingest_fs_asset(
} }
""" """
locator = os.path.abspath(abs_path) locator = os.path.abspath(abs_path)
datetime_now = datetime.now(timezone.utc) datetime_now = utcnow()
out = { out = {
"asset_created": False, "asset_created": False,
@ -246,7 +247,7 @@ async def touch_asset_infos_by_fs_path(
only_if_newer: bool = True, only_if_newer: bool = True,
) -> int: ) -> int:
locator = os.path.abspath(abs_path) locator = os.path.abspath(abs_path)
ts = ts or datetime.now(timezone.utc) ts = ts or utcnow()
stmt = sa.update(AssetInfo).where( stmt = sa.update(AssetInfo).where(
sa.exists( sa.exists(
@ -274,13 +275,31 @@ async def touch_asset_infos_by_fs_path(
return int(res.rowcount or 0) return int(res.rowcount or 0)
async def touch_asset_info_by_id(
session: AsyncSession,
*,
asset_info_id: int,
ts: Optional[datetime] = None,
only_if_newer: bool = True,
) -> int:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
stmt = stmt.values(last_access_time=ts)
res = await session.execute(stmt)
return int(res.rowcount or 0)
async def list_asset_infos_page( async def list_asset_infos_page(
session: AsyncSession, session: AsyncSession,
*, *,
include_tags: Sequence[str] | None = None, include_tags: Optional[Sequence[str]] = None,
exclude_tags: Sequence[str] | None = None, exclude_tags: Optional[Sequence[str]] = None,
name_contains: str | None = None, name_contains: Optional[str] = None,
metadata_filter: dict | None = None, metadata_filter: Optional[dict] = None,
limit: int = 20, limit: int = 20,
offset: int = 0, offset: int = 0,
sort: str = "created_at", sort: str = "created_at",
@ -361,6 +380,19 @@ async def list_asset_infos_page(
return infos, tag_map, total return infos, tag_map, total
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
row = await session.execute(
select(AssetInfo, Asset)
.join(Asset, Asset.hash == AssetInfo.asset_hash)
.where(AssetInfo.id == asset_info_id)
.limit(1)
)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
async def set_asset_info_tags( async def set_asset_info_tags(
session: AsyncSession, session: AsyncSession,
*, *,
@ -374,7 +406,6 @@ async def set_asset_info_tags(
Creates missing tag names as 'user'. Creates missing tag names as 'user'.
""" """
desired = _normalize_tags(tags) desired = _normalize_tags(tags)
now = datetime.now(timezone.utc)
# current links # current links
current = set( current = set(
@ -389,7 +420,7 @@ async def set_asset_info_tags(
if to_add: if to_add:
await _ensure_tags_exist(session, to_add, tag_type="user") await _ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([ session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now) AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=utcnow())
for t in to_add for t in to_add
]) ])
await session.flush() await session.flush()
@ -447,17 +478,23 @@ async def update_asset_info_full(
touched = True touched = True
if touched and user_metadata is None: if touched and user_metadata is None:
info.updated_at = datetime.now(timezone.utc) info.updated_at = utcnow()
await session.flush() await session.flush()
return info return info
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id))
return bool(res.rowcount)
async def replace_asset_info_metadata_projection( async def replace_asset_info_metadata_projection(
session: AsyncSession, session: AsyncSession,
*, *,
asset_info_id: int, asset_info_id: int,
user_metadata: dict | None, user_metadata: Optional[dict],
) -> None: ) -> None:
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`.""" """Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
info = await session.get(AssetInfo, asset_info_id) info = await session.get(AssetInfo, asset_info_id)
@ -465,7 +502,7 @@ async def replace_asset_info_metadata_projection(
raise ValueError(f"AssetInfo {asset_info_id} not found") raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {} info.user_metadata = user_metadata or {}
info.updated_at = datetime.now(timezone.utc) info.updated_at = utcnow()
await session.flush() await session.flush()
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
@ -507,7 +544,7 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
async def list_tags_with_usage( async def list_tags_with_usage(
session, session,
*, *,
prefix: str | None = None, prefix: Optional[str] = None,
limit: int = 100, limit: int = 100,
offset: int = 0, offset: int = 0,
include_zero: bool = True, include_zero: bool = True,
@ -611,7 +648,6 @@ async def add_tags_to_asset_info(
already = [t for t in norm if t in existing] already = [t for t in norm if t in existing]
if to_add: if to_add:
now = datetime.now(timezone.utc)
# Make insert race-safe with a nested tx; ignore dup conflicts if any. # Make insert race-safe with a nested tx; ignore dup conflicts if any.
async with session.begin_nested(): async with session.begin_nested():
session.add_all([ session.add_all([
@ -620,7 +656,7 @@ async def add_tags_to_asset_info(
tag_name=t, tag_name=t,
origin=origin, origin=origin,
added_by=added_by, added_by=added_by,
added_at=now, added_at=utcnow(),
) for t in to_add ) for t in to_add
]) ])
try: try:
@ -677,7 +713,7 @@ async def remove_tags_from_asset_info(
return {"removed": to_remove, "not_present": not_present, "total_tags": total} return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def _normalize_tags(tags: Sequence[str] | None) -> list[str]: def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
return [t.strip().lower() for t in (tags or []) if (t or "").strip()] return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
@ -697,8 +733,8 @@ async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_ty
def _apply_tag_filters( def _apply_tag_filters(
stmt: sa.sql.Select, stmt: sa.sql.Select,
include_tags: Sequence[str] | None, include_tags: Optional[Sequence[str]],
exclude_tags: Sequence[str] | None, exclude_tags: Optional[Sequence[str]],
) -> sa.sql.Select: ) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present.""" """include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = _normalize_tags(include_tags) include_tags = _normalize_tags(include_tags)
@ -724,7 +760,7 @@ def _apply_tag_filters(
def _apply_metadata_filter( def _apply_metadata_filter(
stmt: sa.sql.Select, stmt: sa.sql.Select,
metadata_filter: dict | None, metadata_filter: Optional[dict],
) -> sa.sql.Select: ) -> sa.sql.Select:
"""Apply metadata filters using the projection table asset_info_meta. """Apply metadata filters using the projection table asset_info_meta.

6
app/database/timeutil.py Normal file
View File

@ -0,0 +1,6 @@
from datetime import datetime, timezone
def utcnow() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)