mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-03-06 17:57:40 +08:00
Reduce duplication across assets module
- Extract validate_blake3_hash() into helpers.py, used by upload, schemas, routes - Extract get_reference_with_owner_check() into queries, used by 4 service functions - Extract build_prefix_like_conditions() into queries/common.py, used by 3 queries - Replace 3 inlined tag queries with get_reference_tags() calls - Consolidate AddTagsDict/RemoveTagsDict TypedDicts into AddTagsResult/RemoveTagsResult dataclasses, eliminating manual field copying in tagging.py - Make iter_row_chunks delegate to iter_chunks - Inline trivial compute_filename_for_reference wrapper (unused session param) - Remove mark_assets_missing_outside_prefixes pass-through in bulk_ingest.py - Clean up unused imports (os, time, dependencies_available) - Disable assets routes on DB init failure in main.py Amp-Thread-ID: https://ampcode.com/threads/T-019cb649-dd4e-71ff-9a0e-ae517365207b Co-authored-by: Amp <amp@ampcode.com>
This commit is contained in:
parent
32a6fcf7a8
commit
67c4f79c22
@ -17,6 +17,7 @@ from app.assets.api.schemas_in import (
|
|||||||
AssetValidationError,
|
AssetValidationError,
|
||||||
UploadError,
|
UploadError,
|
||||||
)
|
)
|
||||||
|
from app.assets.helpers import validate_blake3_hash
|
||||||
from app.assets.api.upload import (
|
from app.assets.api.upload import (
|
||||||
delete_temp_file_if_exists,
|
delete_temp_file_if_exists,
|
||||||
parse_multipart_upload,
|
parse_multipart_upload,
|
||||||
@ -89,6 +90,12 @@ def register_assets_routes(
|
|||||||
app.add_routes(ROUTES)
|
app.add_routes(ROUTES)
|
||||||
|
|
||||||
|
|
||||||
|
def disable_assets_routes() -> None:
|
||||||
|
"""Disable asset routes at runtime (e.g. after DB init failure)."""
|
||||||
|
global _ASSETS_ENABLED
|
||||||
|
_ASSETS_ENABLED = False
|
||||||
|
|
||||||
|
|
||||||
def _build_error_response(
|
def _build_error_response(
|
||||||
status: int, code: str, message: str, details: dict | None = None
|
status: int, code: str, message: str, details: dict | None = None
|
||||||
) -> web.Response:
|
) -> web.Response:
|
||||||
@ -116,16 +123,9 @@ def _validate_sort_field(requested: str | None) -> str:
|
|||||||
@_require_assets_feature_enabled
|
@_require_assets_feature_enabled
|
||||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||||
hash_str = request.match_info.get("hash", "").strip().lower()
|
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||||
if not hash_str or ":" not in hash_str:
|
try:
|
||||||
return _build_error_response(
|
hash_str = validate_blake3_hash(hash_str)
|
||||||
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
except ValueError:
|
||||||
)
|
|
||||||
algo, digest = hash_str.split(":", 1)
|
|
||||||
if (
|
|
||||||
algo != "blake3"
|
|
||||||
or not digest
|
|
||||||
or any(c for c in digest if c not in "0123456789abcdef")
|
|
||||||
):
|
|
||||||
return _build_error_response(
|
return _build_error_response(
|
||||||
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
||||||
)
|
)
|
||||||
|
|||||||
@ -2,6 +2,7 @@ import json
|
|||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
|
from app.assets.helpers import validate_blake3_hash
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
ConfigDict,
|
ConfigDict,
|
||||||
@ -116,15 +117,7 @@ class CreateFromHashBody(BaseModel):
|
|||||||
@field_validator("hash")
|
@field_validator("hash")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _require_blake3(cls, v):
|
def _require_blake3(cls, v):
|
||||||
s = (v or "").strip().lower()
|
return validate_blake3_hash(v or "")
|
||||||
if ":" not in s:
|
|
||||||
raise ValueError("hash must be 'blake3:<hex>'")
|
|
||||||
algo, digest = s.split(":", 1)
|
|
||||||
if algo != "blake3":
|
|
||||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
|
||||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
|
||||||
raise ValueError("hash digest must be lowercase hex")
|
|
||||||
return s
|
|
||||||
|
|
||||||
@field_validator("tags", mode="before")
|
@field_validator("tags", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -214,17 +207,10 @@ class UploadAssetSpec(BaseModel):
|
|||||||
def _parse_hash(cls, v):
|
def _parse_hash(cls, v):
|
||||||
if v is None:
|
if v is None:
|
||||||
return None
|
return None
|
||||||
s = str(v).strip().lower()
|
s = str(v).strip()
|
||||||
if not s:
|
if not s:
|
||||||
return None
|
return None
|
||||||
if ":" not in s:
|
return validate_blake3_hash(s)
|
||||||
raise ValueError("hash must be 'blake3:<hex>'")
|
|
||||||
algo, digest = s.split(":", 1)
|
|
||||||
if algo != "blake3":
|
|
||||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
|
||||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
|
||||||
raise ValueError("hash digest must be lowercase hex")
|
|
||||||
return f"{algo}:{digest}"
|
|
||||||
|
|
||||||
@field_validator("tags", mode="before")
|
@field_validator("tags", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@ -7,27 +7,18 @@ from aiohttp import web
|
|||||||
|
|
||||||
import folder_paths
|
import folder_paths
|
||||||
from app.assets.api.schemas_in import ParsedUpload, UploadError
|
from app.assets.api.schemas_in import ParsedUpload, UploadError
|
||||||
|
from app.assets.helpers import validate_blake3_hash
|
||||||
|
|
||||||
|
|
||||||
def normalize_and_validate_hash(s: str) -> str:
|
def normalize_and_validate_hash(s: str) -> str:
|
||||||
"""
|
"""Validate and normalize a hash string.
|
||||||
Validate and normalize a hash string.
|
|
||||||
|
|
||||||
Returns canonical 'blake3:<hex>' or raises UploadError.
|
Returns canonical 'blake3:<hex>' or raises UploadError.
|
||||||
"""
|
"""
|
||||||
s = s.strip().lower()
|
try:
|
||||||
if not s:
|
return validate_blake3_hash(s)
|
||||||
|
except ValueError:
|
||||||
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
if ":" not in s:
|
|
||||||
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
|
||||||
algo, digest = s.split(":", 1)
|
|
||||||
if (
|
|
||||||
algo != "blake3"
|
|
||||||
or len(digest) != 64
|
|
||||||
or any(c for c in digest if c not in "0123456789abcdef")
|
|
||||||
):
|
|
||||||
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
|
||||||
return f"{algo}:{digest}"
|
|
||||||
|
|
||||||
|
|
||||||
async def parse_multipart_upload(
|
async def parse_multipart_upload(
|
||||||
|
|||||||
@ -24,6 +24,7 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
get_or_create_reference,
|
get_or_create_reference,
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
get_reference_by_id,
|
get_reference_by_id,
|
||||||
|
get_reference_with_owner_check,
|
||||||
get_reference_ids_by_ids,
|
get_reference_ids_by_ids,
|
||||||
get_references_by_paths_and_asset_ids,
|
get_references_by_paths_and_asset_ids,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
@ -44,9 +45,9 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
upsert_reference,
|
upsert_reference,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries.tags import (
|
from app.assets.database.queries.tags import (
|
||||||
AddTagsDict,
|
AddTagsResult,
|
||||||
RemoveTagsDict,
|
RemoveTagsResult,
|
||||||
SetTagsDict,
|
SetTagsResult,
|
||||||
add_missing_tag_for_asset_id,
|
add_missing_tag_for_asset_id,
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
bulk_insert_tags_and_meta,
|
bulk_insert_tags_and_meta,
|
||||||
@ -60,10 +61,10 @@ from app.assets.database.queries.tags import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
"AddTagsDict",
|
"AddTagsResult",
|
||||||
"CacheStateRow",
|
"CacheStateRow",
|
||||||
"RemoveTagsDict",
|
"RemoveTagsResult",
|
||||||
"SetTagsDict",
|
"SetTagsResult",
|
||||||
"UnenrichedReferenceRow",
|
"UnenrichedReferenceRow",
|
||||||
"add_missing_tag_for_asset_id",
|
"add_missing_tag_for_asset_id",
|
||||||
"add_tags_to_reference",
|
"add_tags_to_reference",
|
||||||
@ -87,6 +88,7 @@ __all__ = [
|
|||||||
"get_or_create_reference",
|
"get_or_create_reference",
|
||||||
"get_reference_by_file_path",
|
"get_reference_by_file_path",
|
||||||
"get_reference_by_id",
|
"get_reference_by_id",
|
||||||
|
"get_reference_with_owner_check",
|
||||||
"get_reference_ids_by_ids",
|
"get_reference_ids_by_ids",
|
||||||
"get_reference_tags",
|
"get_reference_tags",
|
||||||
"get_references_by_paths_and_asset_ids",
|
"get_references_by_paths_and_asset_ids",
|
||||||
|
|||||||
@ -4,7 +4,6 @@ This module replaces the separate asset_info.py and cache_state.py query modules
|
|||||||
providing a unified interface for the merged asset_references table.
|
providing a unified interface for the merged asset_references table.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import os
|
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from decimal import Decimal
|
from decimal import Decimal
|
||||||
@ -25,6 +24,7 @@ from app.assets.database.models import (
|
|||||||
)
|
)
|
||||||
from app.assets.database.queries.common import (
|
from app.assets.database.queries.common import (
|
||||||
MAX_BIND_PARAMS,
|
MAX_BIND_PARAMS,
|
||||||
|
build_prefix_like_conditions,
|
||||||
build_visible_owner_clause,
|
build_visible_owner_clause,
|
||||||
calculate_rows_per_statement,
|
calculate_rows_per_statement,
|
||||||
iter_chunks,
|
iter_chunks,
|
||||||
@ -165,6 +165,25 @@ def get_reference_by_id(
|
|||||||
return session.get(AssetReference, reference_id)
|
return session.get(AssetReference, reference_id)
|
||||||
|
|
||||||
|
|
||||||
|
def get_reference_with_owner_check(
|
||||||
|
session: Session,
|
||||||
|
reference_id: str,
|
||||||
|
owner_id: str,
|
||||||
|
) -> AssetReference:
|
||||||
|
"""Fetch a reference and verify ownership.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: if reference not found
|
||||||
|
PermissionError: if owner_id doesn't match
|
||||||
|
"""
|
||||||
|
ref = get_reference_by_id(session, reference_id=reference_id)
|
||||||
|
if not ref:
|
||||||
|
raise ValueError(f"AssetReference {reference_id} not found")
|
||||||
|
if ref.owner_id and ref.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
return ref
|
||||||
|
|
||||||
|
|
||||||
def get_reference_by_file_path(
|
def get_reference_by_file_path(
|
||||||
session: Session,
|
session: Session,
|
||||||
file_path: str,
|
file_path: str,
|
||||||
@ -636,12 +655,8 @@ def mark_references_missing_outside_prefixes(
|
|||||||
if not valid_prefixes:
|
if not valid_prefixes:
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
def make_prefix_condition(prefix: str):
|
conds = build_prefix_like_conditions(valid_prefixes)
|
||||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
matches_valid_prefix = sa.or_(*conds)
|
||||||
escaped, esc = escape_sql_like_string(base)
|
|
||||||
return AssetReference.file_path.like(escaped + "%", escape=esc)
|
|
||||||
|
|
||||||
matches_valid_prefix = sa.or_(*[make_prefix_condition(p) for p in valid_prefixes])
|
|
||||||
result = session.execute(
|
result = session.execute(
|
||||||
sa.update(AssetReference)
|
sa.update(AssetReference)
|
||||||
.where(AssetReference.file_path.isnot(None))
|
.where(AssetReference.file_path.isnot(None))
|
||||||
@ -729,13 +744,7 @@ def get_references_for_prefixes(
|
|||||||
if not prefixes:
|
if not prefixes:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
conds = []
|
conds = build_prefix_like_conditions(prefixes)
|
||||||
for p in prefixes:
|
|
||||||
base = os.path.abspath(p)
|
|
||||||
if not base.endswith(os.sep):
|
|
||||||
base += os.sep
|
|
||||||
escaped, esc = escape_sql_like_string(base)
|
|
||||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
sa.select(
|
sa.select(
|
||||||
@ -875,13 +884,7 @@ def get_unenriched_references(
|
|||||||
if not prefixes:
|
if not prefixes:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
conds = []
|
conds = build_prefix_like_conditions(prefixes)
|
||||||
for p in prefixes:
|
|
||||||
base = os.path.abspath(p)
|
|
||||||
if not base.endswith(os.sep):
|
|
||||||
base += os.sep
|
|
||||||
escaped, esc = escape_sql_like_string(base)
|
|
||||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
|
||||||
|
|
||||||
query = (
|
query = (
|
||||||
sa.select(
|
sa.select(
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
"""Shared utilities for database query modules."""
|
"""Shared utilities for database query modules."""
|
||||||
|
|
||||||
|
import os
|
||||||
from typing import Iterable
|
from typing import Iterable
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
|
||||||
from app.assets.database.models import AssetReference
|
from app.assets.database.models import AssetReference
|
||||||
|
from app.assets.helpers import escape_sql_like_string
|
||||||
|
|
||||||
MAX_BIND_PARAMS = 800
|
MAX_BIND_PARAMS = 800
|
||||||
|
|
||||||
@ -24,9 +26,7 @@ def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]
|
|||||||
"""Yield chunks of rows sized to fit within bind param limits."""
|
"""Yield chunks of rows sized to fit within bind param limits."""
|
||||||
if not rows:
|
if not rows:
|
||||||
return
|
return
|
||||||
rows_per_stmt = calculate_rows_per_statement(cols_per_row)
|
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
|
||||||
for i in range(0, len(rows), rows_per_stmt):
|
|
||||||
yield rows[i : i + rows_per_stmt]
|
|
||||||
|
|
||||||
|
|
||||||
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||||
@ -38,3 +38,17 @@ def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
|||||||
if owner_id == "":
|
if owner_id == "":
|
||||||
return AssetReference.owner_id == ""
|
return AssetReference.owner_id == ""
|
||||||
return AssetReference.owner_id.in_(["", owner_id])
|
return AssetReference.owner_id.in_(["", owner_id])
|
||||||
|
|
||||||
|
|
||||||
|
def build_prefix_like_conditions(
|
||||||
|
prefixes: list[str],
|
||||||
|
) -> list[sa.sql.ColumnElement]:
|
||||||
|
"""Build LIKE conditions for matching file paths under directory prefixes."""
|
||||||
|
conds = []
|
||||||
|
for p in prefixes:
|
||||||
|
base = os.path.abspath(p)
|
||||||
|
if not base.endswith(os.sep):
|
||||||
|
base += os.sep
|
||||||
|
escaped, esc = escape_sql_like_string(base)
|
||||||
|
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||||
|
return conds
|
||||||
|
|||||||
@ -1,4 +1,5 @@
|
|||||||
from typing import Iterable, Sequence, TypedDict
|
from dataclasses import dataclass
|
||||||
|
from typing import Iterable, Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import delete, func, select
|
from sqlalchemy import delete, func, select
|
||||||
@ -19,19 +20,22 @@ from app.assets.database.queries.common import (
|
|||||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||||
|
|
||||||
|
|
||||||
class AddTagsDict(TypedDict):
|
@dataclass(frozen=True)
|
||||||
|
class AddTagsResult:
|
||||||
added: list[str]
|
added: list[str]
|
||||||
already_present: list[str]
|
already_present: list[str]
|
||||||
total_tags: list[str]
|
total_tags: list[str]
|
||||||
|
|
||||||
|
|
||||||
class RemoveTagsDict(TypedDict):
|
@dataclass(frozen=True)
|
||||||
|
class RemoveTagsResult:
|
||||||
removed: list[str]
|
removed: list[str]
|
||||||
not_present: list[str]
|
not_present: list[str]
|
||||||
total_tags: list[str]
|
total_tags: list[str]
|
||||||
|
|
||||||
|
|
||||||
class SetTagsDict(TypedDict):
|
@dataclass(frozen=True)
|
||||||
|
class SetTagsResult:
|
||||||
added: list[str]
|
added: list[str]
|
||||||
removed: list[str]
|
removed: list[str]
|
||||||
total: list[str]
|
total: list[str]
|
||||||
@ -81,19 +85,10 @@ def set_reference_tags(
|
|||||||
reference_id: str,
|
reference_id: str,
|
||||||
tags: Sequence[str],
|
tags: Sequence[str],
|
||||||
origin: str = "manual",
|
origin: str = "manual",
|
||||||
) -> SetTagsDict:
|
) -> SetTagsResult:
|
||||||
desired = normalize_tags(tags)
|
desired = normalize_tags(tags)
|
||||||
|
|
||||||
current = set(
|
current = set(get_reference_tags(session, reference_id))
|
||||||
tag_name
|
|
||||||
for (tag_name,) in (
|
|
||||||
session.execute(
|
|
||||||
select(AssetReferenceTag.tag_name).where(
|
|
||||||
AssetReferenceTag.asset_reference_id == reference_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
)
|
|
||||||
|
|
||||||
to_add = [t for t in desired if t not in current]
|
to_add = [t for t in desired if t not in current]
|
||||||
to_remove = [t for t in current if t not in desired]
|
to_remove = [t for t in current if t not in desired]
|
||||||
@ -122,7 +117,7 @@ def set_reference_tags(
|
|||||||
)
|
)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
||||||
|
|
||||||
|
|
||||||
def add_tags_to_reference(
|
def add_tags_to_reference(
|
||||||
@ -132,7 +127,7 @@ def add_tags_to_reference(
|
|||||||
origin: str = "manual",
|
origin: str = "manual",
|
||||||
create_if_missing: bool = True,
|
create_if_missing: bool = True,
|
||||||
reference_row: AssetReference | None = None,
|
reference_row: AssetReference | None = None,
|
||||||
) -> AddTagsDict:
|
) -> AddTagsResult:
|
||||||
if not reference_row:
|
if not reference_row:
|
||||||
ref = session.get(AssetReference, reference_id)
|
ref = session.get(AssetReference, reference_id)
|
||||||
if not ref:
|
if not ref:
|
||||||
@ -141,21 +136,12 @@ def add_tags_to_reference(
|
|||||||
norm = normalize_tags(tags)
|
norm = normalize_tags(tags)
|
||||||
if not norm:
|
if not norm:
|
||||||
total = get_reference_tags(session, reference_id=reference_id)
|
total = get_reference_tags(session, reference_id=reference_id)
|
||||||
return {"added": [], "already_present": [], "total_tags": total}
|
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||||
|
|
||||||
if create_if_missing:
|
if create_if_missing:
|
||||||
ensure_tags_exist(session, norm, tag_type="user")
|
ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
current = {
|
current = set(get_reference_tags(session, reference_id))
|
||||||
tag_name
|
|
||||||
for (tag_name,) in (
|
|
||||||
session.execute(
|
|
||||||
sa.select(AssetReferenceTag.tag_name).where(
|
|
||||||
AssetReferenceTag.asset_reference_id == reference_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
}
|
|
||||||
|
|
||||||
want = set(norm)
|
want = set(norm)
|
||||||
to_add = sorted(want - current)
|
to_add = sorted(want - current)
|
||||||
@ -179,18 +165,18 @@ def add_tags_to_reference(
|
|||||||
nested.rollback()
|
nested.rollback()
|
||||||
|
|
||||||
after = set(get_reference_tags(session, reference_id=reference_id))
|
after = set(get_reference_tags(session, reference_id=reference_id))
|
||||||
return {
|
return AddTagsResult(
|
||||||
"added": sorted(((after - current) & want)),
|
added=sorted(((after - current) & want)),
|
||||||
"already_present": sorted(want & current),
|
already_present=sorted(want & current),
|
||||||
"total_tags": sorted(after),
|
total_tags=sorted(after),
|
||||||
}
|
)
|
||||||
|
|
||||||
|
|
||||||
def remove_tags_from_reference(
|
def remove_tags_from_reference(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
tags: Sequence[str],
|
tags: Sequence[str],
|
||||||
) -> RemoveTagsDict:
|
) -> RemoveTagsResult:
|
||||||
ref = session.get(AssetReference, reference_id)
|
ref = session.get(AssetReference, reference_id)
|
||||||
if not ref:
|
if not ref:
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
raise ValueError(f"AssetReference {reference_id} not found")
|
||||||
@ -198,18 +184,9 @@ def remove_tags_from_reference(
|
|||||||
norm = normalize_tags(tags)
|
norm = normalize_tags(tags)
|
||||||
if not norm:
|
if not norm:
|
||||||
total = get_reference_tags(session, reference_id=reference_id)
|
total = get_reference_tags(session, reference_id=reference_id)
|
||||||
return {"removed": [], "not_present": [], "total_tags": total}
|
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
|
||||||
|
|
||||||
existing = {
|
existing = set(get_reference_tags(session, reference_id))
|
||||||
tag_name
|
|
||||||
for (tag_name,) in (
|
|
||||||
session.execute(
|
|
||||||
sa.select(AssetReferenceTag.tag_name).where(
|
|
||||||
AssetReferenceTag.asset_reference_id == reference_id
|
|
||||||
)
|
|
||||||
)
|
|
||||||
).all()
|
|
||||||
}
|
|
||||||
|
|
||||||
to_remove = sorted(set(t for t in norm if t in existing))
|
to_remove = sorted(set(t for t in norm if t in existing))
|
||||||
not_present = sorted(set(t for t in norm if t not in existing))
|
not_present = sorted(set(t for t in norm if t not in existing))
|
||||||
@ -224,7 +201,7 @@ def remove_tags_from_reference(
|
|||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
total = get_reference_tags(session, reference_id=reference_id)
|
total = get_reference_tags(session, reference_id=reference_id)
|
||||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
|
||||||
|
|
||||||
|
|
||||||
def add_missing_tag_for_asset_id(
|
def add_missing_tag_for_asset_id(
|
||||||
|
|||||||
@ -45,3 +45,21 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
|||||||
- Removing duplicates.
|
- Removing duplicates.
|
||||||
"""
|
"""
|
||||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||||
|
|
||||||
|
|
||||||
|
def validate_blake3_hash(s: str) -> str:
|
||||||
|
"""Validate and normalize a blake3 hash string.
|
||||||
|
|
||||||
|
Returns canonical 'blake3:<hex>' or raises ValueError.
|
||||||
|
"""
|
||||||
|
s = s.strip().lower()
|
||||||
|
if not s or ":" not in s:
|
||||||
|
raise ValueError("hash must be 'blake3:<hex>'")
|
||||||
|
algo, digest = s.split(":", 1)
|
||||||
|
if (
|
||||||
|
algo != "blake3"
|
||||||
|
or len(digest) != 64
|
||||||
|
or any(c for c in digest if c not in "0123456789abcdef")
|
||||||
|
):
|
||||||
|
raise ValueError("hash must be 'blake3:<hex>'")
|
||||||
|
return f"{algo}:{digest}"
|
||||||
|
|||||||
@ -1,6 +1,5 @@
|
|||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
import time
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Literal, TypedDict
|
from typing import Literal, TypedDict
|
||||||
|
|
||||||
@ -16,6 +15,7 @@ from app.assets.database.queries import (
|
|||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_references_for_prefixes,
|
get_references_for_prefixes,
|
||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
|
mark_references_missing_outside_prefixes,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
@ -24,7 +24,6 @@ from app.assets.database.queries import (
|
|||||||
from app.assets.services.bulk_ingest import (
|
from app.assets.services.bulk_ingest import (
|
||||||
SeedAssetSpec,
|
SeedAssetSpec,
|
||||||
batch_insert_seed_assets,
|
batch_insert_seed_assets,
|
||||||
mark_assets_missing_outside_prefixes,
|
|
||||||
)
|
)
|
||||||
from app.assets.services.file_utils import (
|
from app.assets.services.file_utils import (
|
||||||
get_mtime_ns,
|
get_mtime_ns,
|
||||||
@ -39,7 +38,7 @@ from app.assets.services.path_utils import (
|
|||||||
get_comfy_models_folders,
|
get_comfy_models_folders,
|
||||||
get_name_and_tags_from_asset_path,
|
get_name_and_tags_from_asset_path,
|
||||||
)
|
)
|
||||||
from app.database.db import create_session, dependencies_available
|
from app.database.db import create_session
|
||||||
|
|
||||||
|
|
||||||
class _RefInfo(TypedDict):
|
class _RefInfo(TypedDict):
|
||||||
@ -257,7 +256,7 @@ def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
|
|||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
with create_session() as sess:
|
with create_session() as sess:
|
||||||
count = mark_assets_missing_outside_prefixes(sess, prefixes)
|
count = mark_references_missing_outside_prefixes(sess, prefixes)
|
||||||
sess.commit()
|
sess.commit()
|
||||||
return count
|
return count
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -438,11 +437,17 @@ def enrich_asset(
|
|||||||
full_hash: str | None = None
|
full_hash: str | None = None
|
||||||
if compute_hash:
|
if compute_hash:
|
||||||
try:
|
try:
|
||||||
|
mtime_before = get_mtime_ns(stat_p)
|
||||||
digest = compute_blake3_hash(file_path)
|
digest = compute_blake3_hash(file_path)
|
||||||
full_hash = f"blake3:{digest}"
|
stat_after = os.stat(file_path, follow_symlinks=True)
|
||||||
metadata_ok = not extract_metadata or metadata is not None
|
mtime_after = get_mtime_ns(stat_after)
|
||||||
if metadata_ok:
|
if mtime_before != mtime_after:
|
||||||
new_level = ENRICHMENT_HASHED
|
logging.warning("File modified during hashing, discarding hash: %s", file_path)
|
||||||
|
else:
|
||||||
|
full_hash = f"blake3:{digest}"
|
||||||
|
metadata_ok = not extract_metadata or metadata is not None
|
||||||
|
if metadata_ok:
|
||||||
|
new_level = ENRICHMENT_HASHED
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
|
|||||||
@ -12,7 +12,6 @@ from app.assets.services.bulk_ingest import (
|
|||||||
BulkInsertResult,
|
BulkInsertResult,
|
||||||
batch_insert_seed_assets,
|
batch_insert_seed_assets,
|
||||||
cleanup_unreferenced_assets,
|
cleanup_unreferenced_assets,
|
||||||
mark_assets_missing_outside_prefixes,
|
|
||||||
)
|
)
|
||||||
from app.assets.services.file_utils import (
|
from app.assets.services.file_utils import (
|
||||||
get_mtime_ns,
|
get_mtime_ns,
|
||||||
@ -26,8 +25,11 @@ from app.assets.services.ingest import (
|
|||||||
create_from_hash,
|
create_from_hash,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
from app.assets.services.schemas import (
|
from app.assets.database.queries import (
|
||||||
AddTagsResult,
|
AddTagsResult,
|
||||||
|
RemoveTagsResult,
|
||||||
|
)
|
||||||
|
from app.assets.services.schemas import (
|
||||||
AssetData,
|
AssetData,
|
||||||
AssetDetailResult,
|
AssetDetailResult,
|
||||||
AssetSummaryData,
|
AssetSummaryData,
|
||||||
@ -36,7 +38,6 @@ from app.assets.services.schemas import (
|
|||||||
ListAssetsResult,
|
ListAssetsResult,
|
||||||
ReferenceData,
|
ReferenceData,
|
||||||
RegisterAssetResult,
|
RegisterAssetResult,
|
||||||
RemoveTagsResult,
|
|
||||||
TagUsage,
|
TagUsage,
|
||||||
UploadResult,
|
UploadResult,
|
||||||
UserMetadata,
|
UserMetadata,
|
||||||
@ -77,7 +78,6 @@ __all__ = [
|
|||||||
"list_files_recursively",
|
"list_files_recursively",
|
||||||
"list_tags",
|
"list_tags",
|
||||||
"cleanup_unreferenced_assets",
|
"cleanup_unreferenced_assets",
|
||||||
"mark_assets_missing_outside_prefixes",
|
|
||||||
"remove_tags",
|
"remove_tags",
|
||||||
"resolve_asset_for_download",
|
"resolve_asset_for_download",
|
||||||
"set_asset_preview",
|
"set_asset_preview",
|
||||||
|
|||||||
@ -13,6 +13,7 @@ from app.assets.database.queries import (
|
|||||||
fetch_reference_asset_and_tags,
|
fetch_reference_asset_and_tags,
|
||||||
get_asset_by_hash as queries_get_asset_by_hash,
|
get_asset_by_hash as queries_get_asset_by_hash,
|
||||||
get_reference_by_id,
|
get_reference_by_id,
|
||||||
|
get_reference_with_owner_check,
|
||||||
list_references_page,
|
list_references_page,
|
||||||
list_references_by_asset_id,
|
list_references_by_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
@ -23,7 +24,7 @@ from app.assets.database.queries import (
|
|||||||
update_reference_updated_at,
|
update_reference_updated_at,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import select_best_live_path
|
from app.assets.helpers import select_best_live_path
|
||||||
from app.assets.services.path_utils import compute_filename_for_reference
|
from app.assets.services.path_utils import compute_relative_filename
|
||||||
from app.assets.services.schemas import (
|
from app.assets.services.schemas import (
|
||||||
AssetData,
|
AssetData,
|
||||||
AssetDetailResult,
|
AssetDetailResult,
|
||||||
@ -67,18 +68,14 @@ def update_asset_metadata(
|
|||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
) -> AssetDetailResult:
|
) -> AssetDetailResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
ref = get_reference_by_id(session, reference_id=reference_id)
|
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||||
if not ref:
|
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
|
||||||
if ref.owner_id and ref.owner_id != owner_id:
|
|
||||||
raise PermissionError("not owner")
|
|
||||||
|
|
||||||
touched = False
|
touched = False
|
||||||
if name is not None and name != ref.name:
|
if name is not None and name != ref.name:
|
||||||
update_reference_name(session, reference_id=reference_id, name=name)
|
update_reference_name(session, reference_id=reference_id, name=name)
|
||||||
touched = True
|
touched = True
|
||||||
|
|
||||||
computed_filename = compute_filename_for_reference(session, ref)
|
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||||
|
|
||||||
new_meta: dict | None = None
|
new_meta: dict | None = None
|
||||||
if user_metadata is not None:
|
if user_metadata is not None:
|
||||||
@ -183,11 +180,7 @@ def set_asset_preview(
|
|||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
) -> AssetDetailResult:
|
) -> AssetDetailResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
ref_row = get_reference_by_id(session, reference_id=reference_id)
|
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||||
if not ref_row:
|
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
|
||||||
if ref_row.owner_id and ref_row.owner_id != owner_id:
|
|
||||||
raise PermissionError("not owner")
|
|
||||||
|
|
||||||
set_reference_preview(
|
set_reference_preview(
|
||||||
session,
|
session,
|
||||||
|
|||||||
@ -17,7 +17,6 @@ from app.assets.database.queries import (
|
|||||||
get_reference_ids_by_ids,
|
get_reference_ids_by_ids,
|
||||||
get_references_by_paths_and_asset_ids,
|
get_references_by_paths_and_asset_ids,
|
||||||
get_unreferenced_unhashed_asset_ids,
|
get_unreferenced_unhashed_asset_ids,
|
||||||
mark_references_missing_outside_prefixes,
|
|
||||||
restore_references_by_paths,
|
restore_references_by_paths,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import get_utc_now
|
from app.assets.helpers import get_utc_now
|
||||||
@ -266,25 +265,6 @@ def batch_insert_seed_assets(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def mark_assets_missing_outside_prefixes(
|
|
||||||
session: Session, valid_prefixes: list[str]
|
|
||||||
) -> int:
|
|
||||||
"""Mark references as missing when outside valid prefixes.
|
|
||||||
|
|
||||||
This is a non-destructive operation that soft-deletes references
|
|
||||||
by setting is_missing=True. User metadata is preserved and assets
|
|
||||||
can be restored if the file reappears in a future scan.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
session: Database session
|
|
||||||
valid_prefixes: List of absolute directory prefixes that are valid
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
Number of references marked as missing
|
|
||||||
"""
|
|
||||||
return mark_references_missing_outside_prefixes(session, valid_prefixes)
|
|
||||||
|
|
||||||
|
|
||||||
def cleanup_unreferenced_assets(session: Session) -> int:
|
def cleanup_unreferenced_assets(session: Session) -> int:
|
||||||
"""Hard-delete unhashed assets with no active references.
|
"""Hard-delete unhashed assets with no active references.
|
||||||
|
|
||||||
|
|||||||
@ -25,7 +25,6 @@ from app.assets.database.queries import (
|
|||||||
from app.assets.helpers import normalize_tags
|
from app.assets.helpers import normalize_tags
|
||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_filename_for_reference,
|
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
resolve_destination_from_tags,
|
resolve_destination_from_tags,
|
||||||
validate_path_within_base,
|
validate_path_within_base,
|
||||||
@ -163,7 +162,7 @@ def _register_existing_asset(
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
new_meta = dict(user_metadata)
|
new_meta = dict(user_metadata)
|
||||||
computed_filename = compute_filename_for_reference(session, ref)
|
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||||
if computed_filename:
|
if computed_filename:
|
||||||
new_meta["filename"] = computed_filename
|
new_meta["filename"] = computed_filename
|
||||||
|
|
||||||
|
|||||||
@ -150,16 +150,6 @@ def get_asset_category_and_relative_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def compute_filename_for_reference(session, ref) -> str | None:
|
|
||||||
"""Compute the relative filename for an asset reference.
|
|
||||||
|
|
||||||
Uses the file_path from the reference if available.
|
|
||||||
"""
|
|
||||||
if ref.file_path:
|
|
||||||
return compute_relative_filename(ref.file_path)
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||||
"""Return (name, tags) derived from a filesystem path.
|
"""Return (name, tags) derived from a filesystem path.
|
||||||
|
|
||||||
|
|||||||
@ -52,20 +52,6 @@ class IngestResult:
|
|||||||
reference_id: str | None
|
reference_id: str | None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class AddTagsResult:
|
|
||||||
added: list[str]
|
|
||||||
already_present: list[str]
|
|
||||||
total_tags: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
|
||||||
class RemoveTagsResult:
|
|
||||||
removed: list[str]
|
|
||||||
not_present: list[str]
|
|
||||||
total_tags: list[str]
|
|
||||||
|
|
||||||
|
|
||||||
class TagUsage(NamedTuple):
|
class TagUsage(NamedTuple):
|
||||||
name: str
|
name: str
|
||||||
tag_type: str
|
tag_type: str
|
||||||
|
|||||||
@ -1,10 +1,12 @@
|
|||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
|
AddTagsResult,
|
||||||
|
RemoveTagsResult,
|
||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
get_reference_by_id,
|
get_reference_with_owner_check,
|
||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
remove_tags_from_reference,
|
remove_tags_from_reference,
|
||||||
)
|
)
|
||||||
from app.assets.services.schemas import AddTagsResult, RemoveTagsResult, TagUsage
|
from app.assets.services.schemas import TagUsage
|
||||||
from app.database.db import create_session
|
from app.database.db import create_session
|
||||||
|
|
||||||
|
|
||||||
@ -15,13 +17,9 @@ def apply_tags(
|
|||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
) -> AddTagsResult:
|
) -> AddTagsResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
ref_row = get_reference_by_id(session, reference_id=reference_id)
|
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||||
if not ref_row:
|
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
|
||||||
if ref_row.owner_id and ref_row.owner_id != owner_id:
|
|
||||||
raise PermissionError("not owner")
|
|
||||||
|
|
||||||
data = add_tags_to_reference(
|
result = add_tags_to_reference(
|
||||||
session,
|
session,
|
||||||
reference_id=reference_id,
|
reference_id=reference_id,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
@ -31,11 +29,7 @@ def apply_tags(
|
|||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return AddTagsResult(
|
return result
|
||||||
added=data["added"],
|
|
||||||
already_present=data["already_present"],
|
|
||||||
total_tags=data["total_tags"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def remove_tags(
|
def remove_tags(
|
||||||
@ -44,24 +38,16 @@ def remove_tags(
|
|||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
) -> RemoveTagsResult:
|
) -> RemoveTagsResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
ref_row = get_reference_by_id(session, reference_id=reference_id)
|
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||||
if not ref_row:
|
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
|
||||||
if ref_row.owner_id and ref_row.owner_id != owner_id:
|
|
||||||
raise PermissionError("not owner")
|
|
||||||
|
|
||||||
data = remove_tags_from_reference(
|
result = remove_tags_from_reference(
|
||||||
session,
|
session,
|
||||||
reference_id=reference_id,
|
reference_id=reference_id,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
)
|
)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
return RemoveTagsResult(
|
return result
|
||||||
removed=data["removed"],
|
|
||||||
not_present=data["not_present"],
|
|
||||||
total_tags=data["total_tags"],
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def list_tags(
|
def list_tags(
|
||||||
|
|||||||
4
main.py
4
main.py
@ -7,6 +7,7 @@ import folder_paths
|
|||||||
import time
|
import time
|
||||||
from comfy.cli_args import args, enables_dynamic_vram
|
from comfy.cli_args import args, enables_dynamic_vram
|
||||||
from app.logger import setup_logger
|
from app.logger import setup_logger
|
||||||
|
from app.assets.api.routes import disable_assets_routes
|
||||||
from app.assets.seeder import asset_seeder
|
from app.assets.seeder import asset_seeder
|
||||||
import itertools
|
import itertools
|
||||||
import utils.extra_config
|
import utils.extra_config
|
||||||
@ -364,6 +365,9 @@ def setup_database():
|
|||||||
logging.info("Background asset scan initiated for models, input, output")
|
logging.info("Background asset scan initiated for models, input, output")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||||
|
if args.enable_assets:
|
||||||
|
disable_assets_routes()
|
||||||
|
asset_seeder.disable()
|
||||||
|
|
||||||
|
|
||||||
def start_comfyui(asyncio_loop=None):
|
def start_comfyui(asyncio_loop=None):
|
||||||
|
|||||||
@ -104,9 +104,9 @@ class TestSetReferenceTags:
|
|||||||
result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"])
|
result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert set(result["added"]) == {"a", "b"}
|
assert set(result.added) == {"a", "b"}
|
||||||
assert result["removed"] == []
|
assert result.removed == []
|
||||||
assert set(result["total"]) == {"a", "b"}
|
assert set(result.total) == {"a", "b"}
|
||||||
|
|
||||||
def test_removes_old_tags(self, session: Session):
|
def test_removes_old_tags(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
@ -116,9 +116,9 @@ class TestSetReferenceTags:
|
|||||||
result = set_reference_tags(session, reference_id=ref.id, tags=["a"])
|
result = set_reference_tags(session, reference_id=ref.id, tags=["a"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert result["added"] == []
|
assert result.added == []
|
||||||
assert set(result["removed"]) == {"b", "c"}
|
assert set(result.removed) == {"b", "c"}
|
||||||
assert result["total"] == ["a"]
|
assert result.total == ["a"]
|
||||||
|
|
||||||
def test_replaces_tags(self, session: Session):
|
def test_replaces_tags(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
@ -128,9 +128,9 @@ class TestSetReferenceTags:
|
|||||||
result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"])
|
result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert result["added"] == ["c"]
|
assert result.added == ["c"]
|
||||||
assert result["removed"] == ["a"]
|
assert result.removed == ["a"]
|
||||||
assert set(result["total"]) == {"b", "c"}
|
assert set(result.total) == {"b", "c"}
|
||||||
|
|
||||||
|
|
||||||
class TestAddTagsToReference:
|
class TestAddTagsToReference:
|
||||||
@ -141,8 +141,8 @@ class TestAddTagsToReference:
|
|||||||
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
|
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert set(result["added"]) == {"x", "y"}
|
assert set(result.added) == {"x", "y"}
|
||||||
assert result["already_present"] == []
|
assert result.already_present == []
|
||||||
|
|
||||||
def test_reports_already_present(self, session: Session):
|
def test_reports_already_present(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
@ -152,8 +152,8 @@ class TestAddTagsToReference:
|
|||||||
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
|
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert result["added"] == ["y"]
|
assert result.added == ["y"]
|
||||||
assert result["already_present"] == ["x"]
|
assert result.already_present == ["x"]
|
||||||
|
|
||||||
def test_raises_for_missing_reference(self, session: Session):
|
def test_raises_for_missing_reference(self, session: Session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
@ -169,9 +169,9 @@ class TestRemoveTagsFromReference:
|
|||||||
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"])
|
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert set(result["removed"]) == {"a", "b"}
|
assert set(result.removed) == {"a", "b"}
|
||||||
assert result["not_present"] == []
|
assert result.not_present == []
|
||||||
assert result["total_tags"] == ["c"]
|
assert result.total_tags == ["c"]
|
||||||
|
|
||||||
def test_reports_not_present(self, session: Session):
|
def test_reports_not_present(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
@ -181,8 +181,8 @@ class TestRemoveTagsFromReference:
|
|||||||
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"])
|
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"])
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert result["removed"] == ["a"]
|
assert result.removed == ["a"]
|
||||||
assert result["not_present"] == ["x"]
|
assert result.not_present == ["x"]
|
||||||
|
|
||||||
def test_raises_for_missing_reference(self, session: Session):
|
def test_raises_for_missing_reference(self, session: Session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user