mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-19 11:03:00 +08:00
remove timezone; download asset, delete asset endpoints
This commit is contained in:
parent
8d46bec951
commit
0755e5320a
@ -24,8 +24,8 @@ def upgrade() -> None:
|
||||
sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"),
|
||||
sa.Column("storage_backend", sa.String(length=32), nullable=False, server_default="fs"),
|
||||
sa.Column("storage_locator", sa.Text(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
sa.CheckConstraint("refcount >= 0", name="ck_assets_refcount_nonneg"),
|
||||
)
|
||||
@ -41,9 +41,9 @@ def upgrade() -> None:
|
||||
sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sqlite_autoincrement=True,
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
@ -68,7 +68,7 @@ def upgrade() -> None:
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_by", sa.String(length=128), nullable=True),
|
||||
sa.Column("added_at", sa.DateTime(timezone=True), nullable=False, server_default=sa.text("CURRENT_TIMESTAMP")),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
|
||||
@ -1,3 +1,4 @@
|
||||
import urllib.parse
|
||||
from typing import Optional
|
||||
|
||||
from aiohttp import web
|
||||
@ -32,6 +33,39 @@ async def list_assets(request: web.Request) -> web.Response:
|
||||
return web.json_response(payload.model_dump(mode="json"))
|
||||
|
||||
|
||||
|
||||
@ROUTES.get("/api/assets/{id}/content")
|
||||
async def download_asset_content(request: web.Request) -> web.Response:
|
||||
asset_info_id_raw = request.match_info.get("id")
|
||||
try:
|
||||
asset_info_id = int(asset_info_id_raw)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||
if disposition not in {"inline", "attachment"}:
|
||||
disposition = "attachment"
|
||||
|
||||
try:
|
||||
abs_path, content_type, filename = await assets_manager.resolve_asset_content_for_download(
|
||||
asset_info_id=asset_info_id
|
||||
)
|
||||
except ValueError as ve:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||
except NotImplementedError as nie:
|
||||
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||
except FileNotFoundError:
|
||||
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||
|
||||
quoted = filename.replace('"', "'")
|
||||
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||
|
||||
resp = web.FileResponse(abs_path)
|
||||
resp.content_type = content_type
|
||||
resp.headers["Content-Disposition"] = cd
|
||||
return resp
|
||||
|
||||
|
||||
@ROUTES.put("/api/assets/{id}")
|
||||
async def update_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id_raw = request.match_info.get("id")
|
||||
@ -61,6 +95,24 @@ async def update_asset(request: web.Request) -> web.Response:
|
||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||
|
||||
|
||||
@ROUTES.delete("/api/assets/{id}")
|
||||
async def delete_asset(request: web.Request) -> web.Response:
|
||||
asset_info_id_raw = request.match_info.get("id")
|
||||
try:
|
||||
asset_info_id = int(asset_info_id_raw)
|
||||
except Exception:
|
||||
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid integer.")
|
||||
|
||||
try:
|
||||
deleted = await assets_manager.delete_asset_reference(asset_info_id=asset_info_id)
|
||||
except Exception:
|
||||
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||
|
||||
if not deleted:
|
||||
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||
return web.Response(status=204)
|
||||
|
||||
|
||||
@ROUTES.get("/api/tags")
|
||||
async def get_tags(request: web.Request) -> web.Response:
|
||||
query_map = dict(request.rel_url.query)
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
from typing import Optional, Sequence
|
||||
|
||||
from comfy.cli_args import args
|
||||
@ -17,6 +17,9 @@ from .database.services import (
|
||||
list_tags_with_usage,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
fetch_asset_info_and_asset,
|
||||
touch_asset_info_by_id,
|
||||
delete_asset_info_by_id,
|
||||
)
|
||||
from .api import schemas_out
|
||||
|
||||
@ -43,7 +46,7 @@ async def add_local_asset(tags: list[str], file_name: str, file_path: str) -> No
|
||||
|
||||
async with await create_session() as session:
|
||||
if await check_fs_asset_exists_quick(session, file_path=abs_path, size_bytes=size_bytes, mtime_ns=mtime_ns):
|
||||
await touch_asset_infos_by_fs_path(session, abs_path=abs_path, ts=datetime.now(timezone.utc))
|
||||
await touch_asset_infos_by_fs_path(session, abs_path=abs_path)
|
||||
await session.commit()
|
||||
return
|
||||
|
||||
@ -117,6 +120,40 @@ async def list_assets(
|
||||
)
|
||||
|
||||
|
||||
async def resolve_asset_content_for_download(
|
||||
*, asset_info_id: int
|
||||
) -> tuple[str, str, str]:
|
||||
"""
|
||||
Returns (abs_path, content_type, download_name) for the given AssetInfo id.
|
||||
Also touches last_access_time (only_if_newer).
|
||||
Raises:
|
||||
ValueError if AssetInfo not found
|
||||
NotImplementedError for unsupported backend
|
||||
FileNotFoundError if underlying file does not exist (fs backend)
|
||||
"""
|
||||
async with await create_session() as session:
|
||||
pair = await fetch_asset_info_and_asset(session, asset_info_id=asset_info_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
|
||||
if asset.storage_backend != "fs":
|
||||
# Future: support http/s3/gcs/...
|
||||
raise NotImplementedError(f"backend {asset.storage_backend!r} not supported yet")
|
||||
|
||||
abs_path = os.path.abspath(asset.storage_locator)
|
||||
if not os.path.exists(abs_path):
|
||||
raise FileNotFoundError(abs_path)
|
||||
|
||||
await touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
async def update_asset(
|
||||
*,
|
||||
asset_info_id: int,
|
||||
@ -148,6 +185,12 @@ async def update_asset(
|
||||
)
|
||||
|
||||
|
||||
async def delete_asset_reference(*, asset_info_id: int) -> bool:
|
||||
async with await create_session() as session:
|
||||
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
await session.commit()
|
||||
return r
|
||||
|
||||
|
||||
async def list_tags(
|
||||
*,
|
||||
|
||||
@ -14,9 +14,10 @@ from sqlalchemy import (
|
||||
Numeric,
|
||||
Boolean,
|
||||
)
|
||||
from sqlalchemy.sql import func
|
||||
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column, relationship, foreign
|
||||
|
||||
from .timeutil import utcnow
|
||||
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
@ -46,10 +47,10 @@ class Asset(Base):
|
||||
storage_backend: Mapped[str] = mapped_column(String(32), nullable=False, default="fs")
|
||||
storage_locator: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
infos: Mapped[list["AssetInfo"]] = relationship(
|
||||
@ -125,13 +126,13 @@ class AssetInfo(Base):
|
||||
preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=True), nullable=False, server_default=func.current_timestamp()
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
# Relationships
|
||||
@ -221,7 +222,9 @@ class AssetInfoTag(Base):
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_by: Mapped[str | None] = mapped_column(String(128))
|
||||
added_at: Mapped[datetime] = mapped_column(DateTime(timezone=True), nullable=False)
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
)
|
||||
|
||||
asset_info: Mapped["AssetInfo"] = relationship(back_populates="tag_links")
|
||||
tag: Mapped["Tag"] = relationship(back_populates="asset_info_links")
|
||||
|
||||
@ -1,7 +1,7 @@
|
||||
import os
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from datetime import datetime, timezone
|
||||
from datetime import datetime
|
||||
from decimal import Decimal
|
||||
from typing import Any, Sequence, Optional, Iterable
|
||||
|
||||
@ -12,6 +12,7 @@ from sqlalchemy.orm import contains_eager
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from .models import Asset, AssetInfo, AssetInfoTag, AssetLocatorState, Tag, AssetInfoMeta
|
||||
from .timeutil import utcnow
|
||||
|
||||
|
||||
async def check_fs_asset_exists_quick(
|
||||
@ -93,7 +94,7 @@ async def ingest_fs_asset(
|
||||
}
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
datetime_now = datetime.now(timezone.utc)
|
||||
datetime_now = utcnow()
|
||||
|
||||
out = {
|
||||
"asset_created": False,
|
||||
@ -246,7 +247,7 @@ async def touch_asset_infos_by_fs_path(
|
||||
only_if_newer: bool = True,
|
||||
) -> int:
|
||||
locator = os.path.abspath(abs_path)
|
||||
ts = ts or datetime.now(timezone.utc)
|
||||
ts = ts or utcnow()
|
||||
|
||||
stmt = sa.update(AssetInfo).where(
|
||||
sa.exists(
|
||||
@ -274,13 +275,31 @@ async def touch_asset_infos_by_fs_path(
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
async def touch_asset_info_by_id(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
ts: Optional[datetime] = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> int:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
stmt = stmt.values(last_access_time=ts)
|
||||
res = await session.execute(stmt)
|
||||
return int(res.rowcount or 0)
|
||||
|
||||
|
||||
async def list_asset_infos_page(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
include_tags: Optional[Sequence[str]] = None,
|
||||
exclude_tags: Optional[Sequence[str]] = None,
|
||||
name_contains: Optional[str] = None,
|
||||
metadata_filter: Optional[dict] = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
@ -361,6 +380,19 @@ async def list_asset_infos_page(
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
async def fetch_asset_info_and_asset(session: AsyncSession, *, asset_info_id: int) -> Optional[tuple[AssetInfo, Asset]]:
|
||||
row = await session.execute(
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.hash == AssetInfo.asset_hash)
|
||||
.where(AssetInfo.id == asset_info_id)
|
||||
.limit(1)
|
||||
)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
|
||||
async def set_asset_info_tags(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
@ -374,7 +406,6 @@ async def set_asset_info_tags(
|
||||
Creates missing tag names as 'user'.
|
||||
"""
|
||||
desired = _normalize_tags(tags)
|
||||
now = datetime.now(timezone.utc)
|
||||
|
||||
# current links
|
||||
current = set(
|
||||
@ -389,7 +420,7 @@ async def set_asset_info_tags(
|
||||
if to_add:
|
||||
await _ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=now)
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_by=added_by, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
await session.flush()
|
||||
@ -447,17 +478,23 @@ async def update_asset_info_full(
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
async def delete_asset_info_by_id(session: AsyncSession, *, asset_info_id: int) -> bool:
|
||||
"""Delete the user-visible AssetInfo row. Cascades clear tags and metadata."""
|
||||
res = await session.execute(delete(AssetInfo).where(AssetInfo.id == asset_info_id))
|
||||
return bool(res.rowcount)
|
||||
|
||||
|
||||
async def replace_asset_info_metadata_projection(
|
||||
session: AsyncSession,
|
||||
*,
|
||||
asset_info_id: int,
|
||||
user_metadata: dict | None,
|
||||
user_metadata: Optional[dict],
|
||||
) -> None:
|
||||
"""Replaces the `assets_info.user_metadata` AND rebuild the projection rows in `asset_info_meta`."""
|
||||
info = await session.get(AssetInfo, asset_info_id)
|
||||
@ -465,7 +502,7 @@ async def replace_asset_info_metadata_projection(
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = datetime.now(timezone.utc)
|
||||
info.updated_at = utcnow()
|
||||
await session.flush()
|
||||
|
||||
await session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
@ -507,7 +544,7 @@ async def get_asset_tags(session: AsyncSession, *, asset_info_id: int) -> list[s
|
||||
async def list_tags_with_usage(
|
||||
session,
|
||||
*,
|
||||
prefix: str | None = None,
|
||||
prefix: Optional[str] = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
@ -611,7 +648,6 @@ async def add_tags_to_asset_info(
|
||||
already = [t for t in norm if t in existing]
|
||||
|
||||
if to_add:
|
||||
now = datetime.now(timezone.utc)
|
||||
# Make insert race-safe with a nested tx; ignore dup conflicts if any.
|
||||
async with session.begin_nested():
|
||||
session.add_all([
|
||||
@ -620,7 +656,7 @@ async def add_tags_to_asset_info(
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_by=added_by,
|
||||
added_at=now,
|
||||
added_at=utcnow(),
|
||||
) for t in to_add
|
||||
])
|
||||
try:
|
||||
@ -677,7 +713,7 @@ async def remove_tags_from_asset_info(
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def _normalize_tags(tags: Sequence[str] | None) -> list[str]:
|
||||
def _normalize_tags(tags: Optional[Sequence[str]]) -> list[str]:
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
|
||||
|
||||
@ -697,8 +733,8 @@ async def _ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_ty
|
||||
|
||||
def _apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None,
|
||||
exclude_tags: Sequence[str] | None,
|
||||
include_tags: Optional[Sequence[str]],
|
||||
exclude_tags: Optional[Sequence[str]],
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = _normalize_tags(include_tags)
|
||||
@ -724,7 +760,7 @@ def _apply_tag_filters(
|
||||
|
||||
def _apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None,
|
||||
metadata_filter: Optional[dict],
|
||||
) -> sa.sql.Select:
|
||||
"""Apply metadata filters using the projection table asset_info_meta.
|
||||
|
||||
|
||||
6
app/database/timeutil.py
Normal file
6
app/database/timeutil.py
Normal file
@ -0,0 +1,6 @@
|
||||
from datetime import datetime, timezone
|
||||
|
||||
|
||||
def utcnow() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
Loading…
Reference in New Issue
Block a user