From a82577f64aaba1a6b871f9ddc57f1f9ea7531e32 Mon Sep 17 00:00:00 2001 From: bigcat88 Date: Sun, 24 Aug 2025 15:08:53 +0300 Subject: [PATCH] auto-creation of tags and fixed population DB when cloned asset is already present --- alembic_db/versions/0001_assets.py | 16 +-- app/assets_manager.py | 10 +- app/database/models.py | 6 +- app/database/services.py | 167 ++++++++++++++++------------- 4 files changed, 114 insertions(+), 85 deletions(-) diff --git a/alembic_db/versions/0001_assets.py b/alembic_db/versions/0001_assets.py index 47bb43dd8..cdda63fbe 100644 --- a/alembic_db/versions/0001_assets.py +++ b/alembic_db/versions/0001_assets.py @@ -18,7 +18,7 @@ def upgrade() -> None: # ASSETS: content identity (deduplicated by hash) op.create_table( "assets", - sa.Column("hash", sa.String(length=128), primary_key=True), + sa.Column("hash", sa.String(length=256), primary_key=True), sa.Column("size_bytes", sa.BigInteger(), nullable=False, server_default="0"), sa.Column("mime_type", sa.String(length=255), nullable=True), sa.Column("refcount", sa.BigInteger(), nullable=False, server_default="0"), @@ -36,14 +36,15 @@ def upgrade() -> None: op.create_table( "assets_info", sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True), - sa.Column("owner_id", sa.String(length=128), nullable=True), + sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""), sa.Column("name", sa.String(length=512), nullable=False), - sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), - sa.Column("preview_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), + sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False), + sa.Column("preview_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="SET NULL"), nullable=True), sa.Column("user_metadata", sa.JSON(), nullable=True), sa.Column("created_at", sa.DateTime(timezone=False), nullable=False), sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False), sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False), + sa.UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), sqlite_autoincrement=True, ) op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"]) @@ -65,7 +66,7 @@ def upgrade() -> None: op.create_table( "asset_info_tags", sa.Column("asset_info_id", sa.BigInteger(), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False), - sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), + sa.Column("tag_name", sa.String(length=128), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False), sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"), sa.Column("added_by", sa.String(length=128), nullable=True), sa.Column("added_at", sa.DateTime(timezone=False), nullable=False), @@ -77,7 +78,7 @@ def upgrade() -> None: # ASSET_LOCATOR_STATE: 1:1 filesystem metadata(for fast integrity checking) for an Asset records op.create_table( "asset_locator_state", - sa.Column("asset_hash", sa.String(length=128), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), + sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True), sa.Column("mtime_ns", sa.BigInteger(), nullable=True), sa.Column("etag", sa.String(length=256), nullable=True), sa.Column("last_modified", sa.String(length=128), nullable=True), @@ -112,6 +113,8 @@ def upgrade() -> None: [ # Core concept tags {"name": "models", "tag_type": "system"}, + {"name": "input", "tag_type": "system"}, + {"name": "output", "tag_type": "system"}, # Canonical single-word types {"name": "checkpoint", "tag_type": "system"}, @@ -150,6 +153,7 @@ def downgrade() -> None: op.drop_index("ix_tags_tag_type", table_name="tags") op.drop_table("tags") + op.drop_constraint("uq_assets_info_hash_owner_name", table_name="assets_info") op.drop_index("ix_assets_info_last_access_time", table_name="assets_info") op.drop_index("ix_assets_info_created_at", table_name="assets_info") op.drop_index("ix_assets_info_name", table_name="assets_info") diff --git a/app/assets_manager.py b/app/assets_manager.py index f92232a3d..cece14486 100644 --- a/app/assets_manager.py +++ b/app/assets_manager.py @@ -1,6 +1,7 @@ import mimetypes import os from typing import Optional, Sequence +from pathlib import Path from comfy.cli_args import args from comfy_api.internal import async_to_sync @@ -34,8 +35,13 @@ async def asset_exists(*, asset_hash: str) -> bool: def populate_db_with_asset(tags: list[str], file_name: str, file_path: str) -> None: if not args.disable_model_processing: + p = Path(file_name) + dir_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] async_to_sync.AsyncToSyncConverter.run_async_in_thread( - add_local_asset, tags=tags, file_name=file_name, file_path=file_path + add_local_asset, + tags=list(dict.fromkeys([*tags, *dir_parts])), + file_name=p.name, + file_path=file_path, ) @@ -114,7 +120,7 @@ async def list_assets( size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, - preview_url=f"/api/v1/assets/{info.id}/content", # TODO: implement actual content endpoint later + preview_url=f"/api/v1/assets/{info.id}/content", created_at=info.created_at, updated_at=info.updated_at, last_access_time=info.last_access_time, diff --git a/app/database/models.py b/app/database/models.py index 06e46815d..20b88ca68 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -7,6 +7,7 @@ from sqlalchemy import ( DateTime, ForeignKey, Index, + UniqueConstraint, JSON, String, Text, @@ -118,7 +119,7 @@ class AssetInfo(Base): __tablename__ = "assets_info" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - owner_id: Mapped[str | None] = mapped_column(String(128)) + owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) asset_hash: Mapped[str] = mapped_column( String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False @@ -169,6 +170,8 @@ class AssetInfo(Base): ) __table_args__ = ( + UniqueConstraint("asset_hash", "owner_id", "name", name="uq_assets_info_hash_owner_name"), + Index("ix_assets_info_owner_name", "owner_id", "name"), Index("ix_assets_info_owner_id", "owner_id"), Index("ix_assets_info_asset_hash", "asset_hash"), Index("ix_assets_info_name", "name"), @@ -186,7 +189,6 @@ class AssetInfo(Base): return f"" - class AssetInfoMeta(Base): __tablename__ = "asset_info_meta" diff --git a/app/database/services.py b/app/database/services.py index b916a2055..960788f9e 100644 --- a/app/database/services.py +++ b/app/database/services.py @@ -1,3 +1,4 @@ +import contextlib import os import logging from collections import defaultdict @@ -75,7 +76,7 @@ async def ingest_fs_asset( mtime_ns: int, mime_type: Optional[str] = None, info_name: Optional[str] = None, - owner_id: Optional[str] = None, + owner_id: str = "", preview_hash: Optional[str] = None, user_metadata: Optional[dict] = None, tags: Sequence[str] = (), @@ -94,7 +95,7 @@ async def ingest_fs_asset( - Create an AssetInfo (no refcount changes). - Link provided tags to that AssetInfo. * If the require_existing_tags=True, raises ValueError if any tag does not exist in `tags` table. - * If False (default), silently skips unknown tags. + * If False (default), create unknown tags. Returns flags and ids: { @@ -103,8 +104,6 @@ async def ingest_fs_asset( "state_created": bool, "state_updated": bool, "asset_info_id": int | None, - "tags_added": list[str], - "tags_missing": list[str], # filled only when require_existing_tags=False } """ locator = os.path.abspath(abs_path) @@ -116,13 +115,11 @@ async def ingest_fs_asset( "state_created": False, "state_updated": False, "asset_info_id": None, - "tags_added": [], - "tags_missing": [], } # ---- Step 1: INSERT Asset or UPDATE size_bytes/updated_at if exists ---- - async with session.begin_nested() as sp1: - try: + with contextlib.suppress(IntegrityError): + async with session.begin_nested(): session.add( Asset( hash=asset_hash, @@ -137,27 +134,29 @@ async def ingest_fs_asset( ) await session.flush() out["asset_created"] = True - except IntegrityError: - await sp1.rollback() - # Already exists by hash -> update selected fields if different - existing = await session.get(Asset, asset_hash) - if existing is not None: - desired_size = int(size_bytes) - if existing.size_bytes != desired_size: - existing.size_bytes = desired_size - existing.updated_at = datetime_now - out["asset_updated"] = True - else: - # This should not occur. Log for visibility. - logging.error("Asset %s not found after conflict; skipping update.", asset_hash) - except Exception: - await sp1.rollback() - logging.exception("Unexpected error inserting Asset (hash=%s, locator=%s)", asset_hash, locator) - raise + + if not out["asset_created"]: + existing = await session.get(Asset, asset_hash) + if existing is not None: + changed = False + if existing.size_bytes != size_bytes: + existing.size_bytes = size_bytes + changed = True + if mime_type and existing.mime_type != mime_type: + existing.mime_type = mime_type + changed = True + if existing.storage_locator != locator: + existing.storage_locator = locator + changed = True + if changed: + existing.updated_at = datetime_now + out["asset_updated"] = True + else: + logging.error("Asset %s not found after PK conflict; skipping update.", asset_hash) # ---- Step 2: INSERT/UPDATE AssetLocatorState (mtime_ns) ---- - async with session.begin_nested() as sp2: - try: + with contextlib.suppress(IntegrityError): + async with session.begin_nested(): session.add( AssetLocatorState( asset_hash=asset_hash, @@ -166,26 +165,22 @@ async def ingest_fs_asset( ) await session.flush() out["state_created"] = True - except IntegrityError: - await sp2.rollback() - state = await session.get(AssetLocatorState, asset_hash) - if state is not None: - desired_mtime = int(mtime_ns) - if state.mtime_ns != desired_mtime: - state.mtime_ns = desired_mtime - out["state_updated"] = True - else: - logging.debug("Locator state missing for %s after conflict; skipping update.", asset_hash) - except Exception: - await sp2.rollback() - logging.exception("Unexpected error inserting AssetLocatorState (hash=%s)", asset_hash) - raise + + if not out["state_created"]: + state = await session.get(AssetLocatorState, asset_hash) + if state is not None: + desired_mtime = int(mtime_ns) + if state.mtime_ns != desired_mtime: + state.mtime_ns = desired_mtime + out["state_updated"] = True + else: + logging.error("Locator state missing for %s after conflict; skipping update.", asset_hash) # ---- Optional: AssetInfo + tag links ---- if info_name: - # 2a) Create AssetInfo (no refcount bump) - async with session.begin_nested() as sp3: - try: + # 2a) Upsert AssetInfo idempotently on (asset_hash, owner_id, name) + with contextlib.suppress(IntegrityError): + async with session.begin_nested(): info = AssetInfo( owner_id=owner_id, name=info_name, @@ -198,16 +193,35 @@ async def ingest_fs_asset( session.add(info) await session.flush() # get info.id out["asset_info_id"] = info.id - except Exception: - await sp3.rollback() - logging.exception( - "Unexpected error inserting AssetInfo (hash=%s, name=%s)", asset_hash, info_name + + existing_info = ( + await session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_hash == asset_hash, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), ) - raise + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_hash is not None and existing_info.preview_hash != preview_hash: + existing_info.preview_hash = preview_hash + existing_info.updated_at = datetime_now + if existing_info.last_access_time < datetime_now: + existing_info.last_access_time = datetime_now + await session.flush() + out["asset_info_id"] = existing_info.id # 2b) Link tags (if any). We DO NOT create new Tag rows here by default. norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + await _ensure_tags_exist(session, norm, tag_type="user") + # Which tags exist? existing_tag_names = set( name for (name,) in (await session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() @@ -240,8 +254,6 @@ async def ingest_fs_asset( ] ) await session.flush() - out["tags_added"] = to_add - out["tags_missing"] = missing # 2c) Rebuild metadata projection if provided if user_metadata is not None and out["asset_info_id"] is not None: @@ -420,7 +432,7 @@ async def create_asset_info_for_existing_asset( """Create a new AssetInfo referencing an existing Asset (no content write).""" now = utcnow() info = AssetInfo( - owner_id=None, + owner_id="", name=name, asset_hash=asset_hash, preview_hash=None, @@ -688,39 +700,44 @@ async def add_tags_to_asset_info( if create_if_missing: await _ensure_tags_exist(session, norm, tag_type="user") - # Current links - existing = { - tname - for (tname,) in ( + # Snapshot current links + current = { + tag_name + for (tag_name,) in ( await session.execute( sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) ) ).all() } - to_add = [t for t in norm if t not in existing] - already = [t for t in norm if t in existing] + want = set(norm) + to_add = sorted(want - current) if to_add: - # Make insert race-safe with a nested tx; ignore dup conflicts if any. - async with session.begin_nested(): - session.add_all([ - AssetInfoTag( - asset_info_id=asset_info_id, - tag_name=t, - origin=origin, - added_by=added_by, - added_at=utcnow(), - ) for t in to_add - ]) + async with session.begin_nested() as nested: try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_by=added_by, + added_at=utcnow(), + ) + for t in to_add + ] + ) await session.flush() except IntegrityError: - # Another writer linked the same tag at the same time -> ok, treat as already present. - await session.rollback() + await nested.rollback() - total = await get_asset_tags(session, asset_info_id=asset_info_id) - return {"added": sorted(set(to_add)), "already_present": sorted(set(already)), "total_tags": total} + after = set(await get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } async def remove_tags_from_asset_info( @@ -742,8 +759,8 @@ async def remove_tags_from_asset_info( return {"removed": [], "not_present": [], "total_tags": total} existing = { - tname - for (tname,) in ( + tag_name + for (tag_name,) in ( await session.execute( sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) )