add support for asset duplicates

bigcat88 2025-09-06 19:22:51 +03:00
parent 789a62ce35
commit 2d9be462d3
6 changed files with 116 additions and 62 deletions

View File

@ -1,5 +1,4 @@
# File: /alembic_db/versions/0001_assets.py
"""initial assets schema + per-asset state cache
"""initial assets schema
Revision ID: 0001_assets
Revises:
@ -69,15 +68,18 @@ def upgrade() -> None:
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
# ASSET_CACHE_STATE: 1:1 local cache metadata for an Asset
# ASSET_CACHE_STATE: N:1 local cache metadata rows per Asset
op.create_table(
"asset_cache_state",
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True),
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_hash", sa.String(length=256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False), # absolute local path to cached file
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_hash", "asset_cache_state", ["asset_hash"])
# ASSET_INFO_META: typed KV projection of user_metadata for filtering/sorting
op.create_table(
@ -144,7 +146,7 @@ def upgrade() -> None:
{"name": "photomaker", "tag_type": "system"},
{"name": "classifiers", "tag_type": "system"},
# Extra basic tags (used for vae_approx, ...)
# Extra basic tags
{"name": "encoder", "tag_type": "system"},
{"name": "decoder", "tag_type": "system"},
],
@ -162,6 +164,7 @@ def downgrade() -> None:
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_cache_state_asset_hash", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
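
For context, a minimal standalone sketch of the shape this migration now produces: several asset_cache_state rows may point at the same asset_hash (duplicate on-disk copies), while uq_asset_cache_state_file_path still forbids two rows for one path. Table and constraint names mirror the migration; the trimmed-down assets table, the in-memory SQLite engine, and the sample paths are illustrative only, not part of the commit.

import sqlalchemy as sa

metadata = sa.MetaData()
assets = sa.Table(
    "assets", metadata,
    sa.Column("hash", sa.String(256), primary_key=True),  # trimmed to the FK target only
)
asset_cache_state = sa.Table(
    "asset_cache_state", metadata,
    sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
    sa.Column("asset_hash", sa.String(256), sa.ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False),
    sa.Column("file_path", sa.Text(), nullable=False),
    sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
    sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
    sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)

engine = sa.create_engine("sqlite://")
metadata.create_all(engine)
with engine.begin() as conn:
    conn.execute(assets.insert(), [{"hash": "abc"}])
    # Two copies of the same asset at different paths are now representable:
    conn.execute(asset_cache_state.insert(), [
        {"asset_hash": "abc", "file_path": "/models/loras/a.safetensors"},
        {"asset_hash": "abc", "file_path": "/models/loras/a_copy.safetensors"},
    ])
    # Re-inserting either file_path would raise IntegrityError on the unique constraint.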

View File

@ -147,7 +147,7 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
return AssetInfo.owner_id.in_(["", owner_id])
def compute_model_relative_filename(file_path: str) -> str | None:
def compute_model_relative_filename(file_path: str) -> Optional[str]:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:

View File

@ -8,7 +8,7 @@ import aiohttp
from .storage.hashing import blake3_hash_sync
from .database.db import create_session
from .database.services import ingest_fs_asset, get_cache_state_by_asset_hash
from .database.services import ingest_fs_asset, list_cache_states_by_asset_hash
from .resolvers import resolve_asset
from ._assets_helpers import resolve_destination_from_tags, ensure_within_base
@ -26,20 +26,25 @@ async def ensure_asset_cached(
tags_hint: Optional[list[str]] = None,
) -> str:
"""
Ensure there is a verified local file for `asset_hash` in the correct Comfy folder.
Policy:
- Resolver must provide valid tags (root and, for models, category).
- If target path already exists:
* if hash matches -> reuse & ingest
* else -> remove and overwrite with the correct content
Ensure there is a verified local file for asset_hash in the correct Comfy folder.
Fast path:
- If any cache_state row has a file_path that exists, return it immediately.
Preference order is the oldest ID first for stability.
Slow path:
- Resolve remote location + placement tags.
- Download to the correct folder, verify hash, move into place.
- Ingest identity + cache state so future fast passes can skip hashing.
"""
lock = _FETCH_LOCKS.setdefault(asset_hash, asyncio.Lock())
async with lock:
# 1) If we already have a state -> trust the path
# 1) If we already have any cache_state path present on disk, use it (oldest-first)
async with await create_session() as sess:
state = await get_cache_state_by_asset_hash(sess, asset_hash=asset_hash)
if state and os.path.isfile(state.file_path):
return state.file_path
states = await list_cache_states_by_asset_hash(sess, asset_hash=asset_hash)
for s in states:
if s and s.file_path and os.path.isfile(s.file_path):
return s.file_path
# 2) Resolve remote location + placement hints (must include valid tags)
res = await resolve_asset(asset_hash)
@ -107,7 +112,7 @@ async def ensure_asset_cached(
finally:
raise ValueError(f"Hash mismatch: expected {asset_hash}, got {canonical}")
# 7) Atomically move into place (we already removed an invalid file if it existed)
# 7) Atomically move into place
if os.path.exists(final_path):
os.remove(final_path)
os.replace(tmp_path, final_path)
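
A standalone sketch of the fast path described in the docstring above, restated in isolation so the oldest-first preference and the per-hash lock pattern are easy to read; the AssetCacheState-like rows are assumed to come from list_cache_states_by_asset_hash ordered oldest id first, and none of this is the commit's actual code.

import asyncio
import os
from typing import Optional, Sequence

_LOCKS: dict[str, asyncio.Lock] = {}

async def cached_path_fast(asset_hash: str, states: Sequence) -> Optional[str]:
    # One lock per asset hash, so concurrent callers for the same asset serialize.
    lock = _LOCKS.setdefault(asset_hash, asyncio.Lock())
    async with lock:
        # `states` are ordered oldest id first; the first copy still on disk wins.
        for s in states:
            if s and s.file_path and os.path.isfile(s.file_path):
                return s.file_path
    return None  # caller falls through to the slow path: resolve, download, verify, ingest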

View File

@ -1,4 +1,5 @@
import asyncio
import contextlib
import logging
import os
import uuid
@ -106,7 +107,7 @@ async def schedule_scans(roots: Sequence[str]) -> schemas_out.AssetScanStatusRes
async def fast_reconcile_and_kickoff(
roots: Sequence[str] | None = None,
roots: Optional[Sequence[str]] = None,
*,
progress_cb: Optional[Callable[[str, str, int, bool, dict], None]] = None,
) -> schemas_out.AssetScanStatusResponse:
@ -216,18 +217,18 @@ async def _fast_reconcile_into_queue(
"""
if root == "models":
files = _collect_models_files()
preset_discovered = len(files)
preset_discovered = _count_nonzero_in_list(files)
files_iter = asyncio.Queue()
for p in files:
await files_iter.put(p)
await files_iter.put(None) # sentinel for our local draining loop
elif root == "input":
base = folder_paths.get_input_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base))
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
files_iter = await _queue_tree_files(base)
elif root == "output":
base = folder_paths.get_output_directory()
preset_discovered = _count_files_in_tree(os.path.abspath(base))
preset_discovered = _count_files_in_tree(os.path.abspath(base), only_nonzero=True)
files_iter = await _queue_tree_files(base)
else:
raise RuntimeError(f"Unsupported root: {root}")
@ -378,26 +379,41 @@ def _collect_models_files() -> list[str]:
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
try:
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
except Exception:
pass
if allowed:
out.append(abs_path)
return out
def _count_files_in_tree(base_abs: str) -> int:
def _count_files_in_tree(base_abs: str, *, only_nonzero: bool = False) -> int:
if not os.path.isdir(base_abs):
return 0
total = 0
for _dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
total += len(filenames)
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
if not only_nonzero:
total += len(filenames)
else:
for name in filenames:
with contextlib.suppress(OSError):
st = os.stat(os.path.join(dirpath, name), follow_symlinks=True)
if st.st_size:
total += 1
return total
def _count_nonzero_in_list(paths: list[str]) -> int:
cnt = 0
for p in paths:
with contextlib.suppress(OSError):
st = os.stat(p, follow_symlinks=True)
if st.st_size:
cnt += 1
return cnt
async def _queue_tree_files(base_dir: str) -> asyncio.Queue:
"""
Walk base_dir in a worker thread and return a queue prefilled with all paths,
@ -455,7 +471,7 @@ def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e:
e["discovered"],
e["queued"],
)
elif e.get("checked", 0) % 500 == 0: # do not spam with fast progress
elif e.get("checked", 0) % 1000 == 0: # do not spam with fast progress
logging.info(
"[assets][%s] fast progress: processed=%s/%s",
root,
@ -464,12 +480,13 @@ def _console_cb(root: str, phase: str, total_processed: int, finished: bool, e:
)
elif phase == "slow":
if finished:
logging.info(
"[assets][%s] slow done: %s/%s",
root,
e.get("slow_queue_finished", 0),
e.get("slow_queue_total", 0),
)
if e.get("slow_queue_finished", 0) or e.get("slow_queue_total", 0):
logging.info(
"[assets][%s] slow done: %s/%s",
root,
e.get("slow_queue_finished", 0),
e.get("slow_queue_total", 0),
)
elif e.get('slow_queue_finished', 0) % 3 == 0:
logging.info(
"[assets][%s] slow progress: %s/%s",

View File

@ -1,3 +1,5 @@
from __future__ import annotations
from datetime import datetime
from typing import Any, Optional
import uuid
@ -66,9 +68,8 @@ class Asset(Base):
viewonly=True,
)
cache_state: Mapped["AssetCacheState | None"] = relationship(
cache_states: Mapped[list["AssetCacheState"]] = relationship(
back_populates="asset",
uselist=False,
cascade="all, delete-orphan",
passive_deletes=True,
)
@ -93,24 +94,25 @@ class Asset(Base):
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
asset_hash: Mapped[str] = mapped_column(
String(256), ForeignKey("assets.hash", ondelete="CASCADE"), primary_key=True
)
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_hash: Mapped[str] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
asset: Mapped["Asset"] = relationship(back_populates="cache_state", uselist=False)
asset: Mapped["Asset"] = relationship(back_populates="cache_states")
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_hash", "asset_hash"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<AssetCacheState hash={self.asset_hash[:12]} path={self.file_path!r}>"
return f"<AssetCacheState id={self.id} hash={self.asset_hash[:12]} path={self.file_path!r}>"
class AssetLocation(Base):

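To illustrate the ORM side of the 1:N change, a hedged sketch of reading every duplicate path for one asset through the new cache_states relationship; it assumes Asset.hash is the primary-key column (as the foreign key suggests) and an AsyncSession, and is not part of the commit.

from sqlalchemy import select
from sqlalchemy.orm import selectinload

async def duplicate_paths(session, asset_hash: str) -> list[str]:
    asset = (
        await session.execute(
            select(Asset)
            .options(selectinload(Asset.cache_states))  # eager-load; lazy loads do not work on AsyncSession
            .where(Asset.hash == asset_hash)
        )
    ).scalars().first()
    # One file_path per on-disk copy; an empty list means nothing is cached locally.
    return [s.file_path for s in asset.cache_states] if asset else []
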
View File

@ -4,7 +4,7 @@ import logging
from collections import defaultdict
from datetime import datetime
from decimal import Decimal
from typing import Any, Sequence, Optional, Iterable
from typing import Any, Sequence, Optional, Iterable, Union
import sqlalchemy as sa
from sqlalchemy.ext.asyncio import AsyncSession
@ -82,14 +82,14 @@ async def ingest_fs_asset(
require_existing_tags: bool = False,
) -> dict:
"""
Upsert Asset identity row + cache state pointing at local file.
Upsert Asset identity row + cache state(s) pointing at local file.
Always:
- Insert Asset if missing;
- Insert AssetCacheState if missing; else update mtime_ns if different.
- Insert AssetCacheState if missing; else update mtime_ns and asset_hash if different.
Optionally (when info_name is provided):
- Create an AssetInfo.
- Create or update an AssetInfo on (asset_hash, owner_id, name).
- Link provided tags to that AssetInfo.
* If require_existing_tags=True, raises ValueError if any tag does not exist in the `tags` table.
* If False (default), create unknown tags.
@ -157,11 +157,16 @@ async def ingest_fs_asset(
out["state_created"] = True
if not out["state_created"]:
state = await session.get(AssetCacheState, asset_hash)
# most likely a unique(file_path) conflict; update that row
state = (
await session.execute(
select(AssetCacheState).where(AssetCacheState.file_path == locator).limit(1)
)
).scalars().first()
if state is not None:
changed = False
if state.file_path != locator:
state.file_path = locator
if state.asset_hash != asset_hash:
state.asset_hash = asset_hash
changed = True
if state.mtime_ns != int(mtime_ns):
state.mtime_ns = int(mtime_ns)
@ -260,7 +265,15 @@ async def ingest_fs_asset(
# )
# start of adding metadata["filename"]
if out["asset_info_id"] is not None:
computed_filename = compute_model_relative_filename(abs_path)
primary_path = (
await session.execute(
select(AssetCacheState.file_path)
.where(AssetCacheState.asset_hash == asset_hash)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
computed_filename = compute_model_relative_filename(primary_path) if primary_path else None
# Start from current metadata on this AssetInfo, if any
current_meta = existing_info.user_metadata or {}
@ -366,7 +379,6 @@ async def list_asset_infos_page(
base = _apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter)
# Sort
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
@ -381,7 +393,6 @@ async def list_asset_infos_page(
base = base.order_by(sort_exp).limit(limit).offset(offset)
# Total count (same filters, no ordering/limit/offset)
count_stmt = (
select(func.count())
.select_from(AssetInfo)
@ -395,10 +406,9 @@ async def list_asset_infos_page(
total = int((await session.execute(count_stmt)).scalar_one() or 0)
# Fetch rows
infos = (await session.execute(base)).scalars().unique().all()
# Collect tags in bulk (single query)
# Collect tags in bulk
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
@ -470,12 +480,33 @@ async def fetch_asset_info_asset_and_tags(
async def get_cache_state_by_asset_hash(session: AsyncSession, *, asset_hash: str) -> Optional[AssetCacheState]:
return await session.get(AssetCacheState, asset_hash)
"""Return the oldest cache row for this asset."""
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_hash == asset_hash)
.order_by(AssetCacheState.id.asc())
.limit(1)
)
).scalars().first()
async def list_cache_states_by_asset_hash(
session: AsyncSession, *, asset_hash: str
) -> Union[list[AssetCacheState], Sequence[AssetCacheState]]:
"""Return all cache rows for this asset ordered by oldest first."""
return (
await session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_hash == asset_hash)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
async def list_asset_locations(
session: AsyncSession, *, asset_hash: str, provider: Optional[str] = None
) -> list[AssetLocation] | Sequence[AssetLocation]:
) -> Union[list[AssetLocation], Sequence[AssetLocation]]:
stmt = select(AssetLocation).where(AssetLocation.asset_hash == asset_hash)
if provider:
stmt = stmt.where(AssetLocation.provider == provider)
@ -815,7 +846,6 @@ async def list_tags_with_usage(
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
# Ordering
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else: # default "count_desc"
@ -990,6 +1020,7 @@ def _apply_tag_filters(
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: Optional[dict],
@ -1050,7 +1081,7 @@ def _apply_metadata_filter(
for k, v in metadata_filter.items():
if isinstance(v, list):
# ANY-of (exists for any element)
ors = [ _exists_clause_for_value(k, elem) for elem in v ]
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
@ -1079,12 +1110,10 @@ def _project_kv(key: str, value: Any) -> list[dict]:
"""
rows: list[dict] = []
# None
if value is None:
rows.append({"key": key, "ordinal": 0, "val_json": None})
return rows
# Scalars
if _is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
@ -1099,9 +1128,7 @@ def _project_kv(key: str, value: Any) -> list[dict]:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
# Lists
if isinstance(value, list):
# list of scalars?
if all(_is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
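
To close, a standalone sketch of the uq_asset_cache_state_file_path fallback from the ingest_fs_asset hunk earlier in this file: when inserting a cache row fails because the path already has a row, that existing row is re-pointed at the new asset_hash and its mtime_ns refreshed. Model and column names follow the code above; the session, the helper name, and the calling convention are assumptions, and this is not the commit's code.

from sqlalchemy import select

async def repoint_cache_row(session, *, file_path: str, asset_hash: str, mtime_ns: int) -> bool:
    state = (
        await session.execute(
            select(AssetCacheState).where(AssetCacheState.file_path == file_path).limit(1)
        )
    ).scalars().first()
    if state is None:
        return False  # no conflicting row after all; the caller can retry the insert
    changed = False
    if state.asset_hash != asset_hash:
        state.asset_hash = asset_hash      # the same path now stores different content
        changed = True
    if state.mtime_ns != int(mtime_ns):
        state.mtime_ns = int(mtime_ns)
        changed = True
    return changed                         # flush/commit is left to the caller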