ComfyUI/app/assets/database/queries/asset.py
Luke Mino-Altherr 32a6fcf7a8
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
fix: address code review feedback - round 2
- Reject path separators (/, \, os.sep) in tag components for defense-in-depth
- Add comment explaining double-relpath normalization trick
- Add _require_assets_feature_enabled decorator returning 503 when disabled
- Call asset_seeder.disable() when --enable-assets is not passed
- Add iter_chunks to bulk_update_needs_verify, bulk_update_is_missing,
  and delete_references_by_ids to respect SQLite bind param limits
- Fix CacheStateRow.size_bytes NULL coercion (0 -> None) to avoid
  false needs_verify flags on assets with unknown size
- Add PermissionError catch in delete_asset_tags route (403 vs 500)
- Add hash-is-None guard in delete_orphaned_seed_asset
- Validate from_asset_id in reassign_asset_references
- Initialize _prune_first in __init__, remove getattr workaround
- Cap error accumulation in _add_error to 200
- Remove confirmed dead code: seed_assets, compute_filename_for_asset,
  ALLOWED_ROOTS, AssetNotFoundError, SetTagsResult, update_enrichment_level,
  Asset.to_dict, AssetReference.to_dict, _AssetSeeder.enable

Amp-Thread-ID: https://ampcode.com/threads/T-019cb610-1b55-74b6-8dbb-381d73c387c0
Co-authored-by: Amp <amp@ampcode.com>
2026-03-03 16:35:55 -08:00

141 lines
3.7 KiB
Python

import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
if not rows:
return
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)
def get_existing_asset_ids(
session: Session,
asset_ids: list[str],
) -> set[str]:
"""Return the subset of asset_ids that exist in the database."""
if not asset_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
rows = session.execute(
select(Asset.id).where(Asset.id.in_(chunk))
).fetchall()
found.update(row[0] for row in rows)
return found
def update_asset_hash_and_mime(
session: Session,
asset_id: str,
asset_hash: str | None = None,
mime_type: str | None = None,
) -> bool:
"""Update asset hash and/or mime_type. Returns True if asset was found."""
asset = session.get(Asset, asset_id)
if not asset:
return False
if asset_hash is not None:
asset.hash = asset_hash
if mime_type is not None:
asset.mime_type = mime_type
return True
def reassign_asset_references(
session: Session,
from_asset_id: str,
to_asset_id: str,
reference_id: str,
) -> None:
"""Reassign a reference from one asset to another.
Used when merging a stub asset into an existing asset with the same hash.
"""
ref = session.get(AssetReference, reference_id)
if ref and ref.asset_id == from_asset_id:
ref.asset_id = to_asset_id
session.flush()