diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py index 252242eae..001dfa324 100644 --- a/app/api/assets_routes.py +++ b/app/api/assets_routes.py @@ -283,7 +283,7 @@ async def upload_asset(request: web.Request) -> web.Response: "Uploaded file hash does not match provided hash.", ) return _error_response(400, "BAD_REQUEST", "Invalid inputs.") - except Exception as e: + except Exception: if tmp_path and os.path.exists(tmp_path): os.remove(tmp_path) return _error_response(500, "INTERNAL", "Unexpected server error.") diff --git a/app/database/_helpers.py b/app/database/_helpers.py index 5ce972076..a031e861c 100644 --- a/app/database/_helpers.py +++ b/app/database/_helpers.py @@ -67,32 +67,29 @@ def apply_metadata_filter( return stmt def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - subquery = ( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( - AssetInfoMeta.asset_info_id == AssetInfo.id, - AssetInfoMeta.key == key, - *preds, - ) - .limit(1) + return sa.exists().where( + AssetInfoMeta.asset_info_id == AssetInfo.id, + AssetInfoMeta.key == key, + *preds, ) - return sa.exists(subquery) def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: # Missing OR null: if value is None: # either: no row for key OR a row for key with explicit null - no_row_for_key = ~sa.exists( - select(sa.literal(1)) - .select_from(AssetInfoMeta) - .where( + no_row_for_key = sa.not_( + sa.exists().where( AssetInfoMeta.asset_info_id == AssetInfo.id, AssetInfoMeta.key == key, ) - .limit(1) ) - null_row = _exists_for_pred(key, AssetInfoMeta.val_json.is_(None)) + null_row = _exists_for_pred( + key, + AssetInfoMeta.val_json.is_(None), + AssetInfoMeta.val_str.is_(None), + AssetInfoMeta.val_num.is_(None), + AssetInfoMeta.val_bool.is_(None), + ) return sa.or_(no_row_for_key, null_row) # Typed scalar matches: @@ -135,13 +132,19 @@ def project_kv(key: str, value: Any) -> list[dict]: - scalar -> one row (ordinal=0) in the proper typed column - list of scalars -> one row per element with ordinal=i - dict or list with non-scalars -> single row with val_json (or one per element w/ val_json if list) - - None -> single row with val_json = None + - None -> single row with all value columns NULL Each row: {"key": key, "ordinal": i, "val_str"/"val_num"/"val_bool"/"val_json": ...} """ rows: list[dict] = [] + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + if value is None: - rows.append({"key": key, "ordinal": 0, "val_json": None}) + rows.append(_null_row(0)) return rows if is_scalar(value): @@ -162,7 +165,7 @@ def project_kv(key: str, value: Any) -> list[dict]: if all(is_scalar(x) for x in value): for i, x in enumerate(value): if x is None: - rows.append({"key": key, "ordinal": i, "val_json": None}) + rows.append(_null_row(i)) elif isinstance(x, bool): rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) elif isinstance(x, (int, float, Decimal)): diff --git a/app/database/models.py b/app/database/models.py index 203867468..5bb3a09bc 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -145,7 +145,7 @@ class AssetInfo(Base): String(256), ForeignKey("assets.hash", ondelete="RESTRICT"), nullable=False ) preview_hash: Mapped[str | None] = mapped_column(String(256), ForeignKey("assets.hash", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=utcnow ) @@ -220,7 +220,7 @@ class AssetInfoMeta(Base): val_str: Mapped[Optional[str]] = mapped_column(String(2048), nullable=True) val_num: Mapped[Optional[float]] = mapped_column(Numeric(38, 10), nullable=True) val_bool: Mapped[Optional[bool]] = mapped_column(Boolean, nullable=True) - val_json: Mapped[Optional[Any]] = mapped_column(JSON, nullable=True) + val_json: Mapped[Optional[Any]] = mapped_column(JSON(none_as_null=True), nullable=True) asset_info: Mapped["AssetInfo"] = relationship(back_populates="metadata_entries") diff --git a/tests-assets/test_metadata_filters.py b/tests-assets/test_metadata_filters.py new file mode 100644 index 000000000..39d00fa2d --- /dev/null +++ b/tests-assets/test_metadata_filters.py @@ -0,0 +1,378 @@ +import json +import aiohttp +import pytest + + +@pytest.mark.asyncio +async def test_meta_and_across_keys_and_types(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes): + name = "mf_and_mix.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-and"] + meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + await asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) + + # All keys must match (AND semantics) + f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_ok), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = [a["name"] for a in b1["assets"]] + assert name in names + + # One key mismatched -> no result + f_bad = {"purpose": "mix", "epoch": 2, "active": True} + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_bad), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + assert not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): + name = "mf_types.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-types"] + meta = {"epoch": 1, "active": True} + await asset_factory(name, tags, meta, make_asset_bytes(name)) + + # int filter matches numeric + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": 1}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # string "1" must NOT match numeric 1 + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": "1"}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + # bool True matches, string "true" must NOT match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": True}), + }, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 and any(a["name"] == name for a in b3["assets"]) + + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": "true"}), + }, + ) as r4: + b4 = await r4.json() + assert r4.status == 200 and not b4["assets"] + + +@pytest.mark.asyncio +async def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_scalars.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-list"] + meta = {"flags": ["red", "green"]} + await asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) + + # Any-of should match because "green" is present + filt_ok = {"flags": ["blue", "green"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # None of provided flags present -> no match + filt_miss = {"flags": ["blue", "yellow"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + # Duplicates in list should not break matching + filt_dup = {"flags": ["green", "green", "green"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)}, + ) as r3: + b3 = await r3.json() + assert r3.status == 200 and any(a["name"] == name for a in b3["assets"]) + + +@pytest.mark.asyncio +async def test_meta_none_semantics_missing_or_null_and_any_of_with_none(http, api_base, asset_factory, make_asset_bytes): + # a1: key missing; a2: explicit null; a3: concrete value + t = ["models", "checkpoints", "unit-tests", "mf-none"] + a1 = await asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) + a2 = await asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) + a3 = await asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) + + # Filter {maybe: None} must match a1 and a2, not a3 + filt = {"maybe": None} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + got = [a["name"] for a in b1["assets"]] + assert a1["name"] in got and a2["name"] in got and a3["name"] not in got + + # Any-of with None should include missing/null plus value matches + filt_any = {"maybe": [None, "x"]} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2 + + +@pytest.mark.asyncio +async def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): + name = "mf_nested_json.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} + await asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) + + # Exact JSON object equality (same structure) + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": cfg}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Different JSON object should not match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_objects.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] + await asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) + + # Any-of for list of objects should match when one element equals the filter object + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}), + }, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Non-matching object -> no match + async with http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}), + }, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): + name = "mf_keys_unicode.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + meta = { + "weird.key": "v1", + "path/like": 7, + "with:colon": True, + "ключ": "значение", + "emoji": "🐍", + } + await asset_factory(name, tags, meta, make_asset_bytes(name, 1500)) + + # Match all the special keys + filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"} + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Unicode key match + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and any(a["name"] == name for a in b2["assets"]) + + +@pytest.mark.asyncio +async def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + a0 = await asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) + a1 = await asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) + + # count == 0 must match only a0 + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names1 = [a["name"] for a in b1["assets"]] + assert a0["name"] in names1 and a1["name"] not in names1 + + # Any-of list of booleans: True matches second asset + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and any(a["name"] == a1["name"] for a in b2["assets"]) + + +@pytest.mark.asyncio +async def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): + name = "mf_mixed_list.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + meta = {"mix": ["1", 1, True, None]} + await asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) + + # Should match because 1 is present + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 and any(a["name"] == name for a in b1["assets"]) + + # Should NOT match for False + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): + # Use a unique scope tag to avoid interference + t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + x = await asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) + y = await asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) + + # Filtering by unknown key with None should return both (missing key OR null) + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})}, + ) as r1: + b1 = await r1.json() + assert r1.status == 200 + names = {a["name"] for a in b1["assets"]} + assert x["name"] in names and y["name"] in names + + # Filtering by unknown key with concrete value should return none + async with http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})}, + ) as r2: + b2 = await r2.json() + assert r2.status == 200 and not b2["assets"] + + +@pytest.mark.asyncio +async def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes): + # alpha matches epoch=1; beta has epoch=2 + a = await asset_factory( + "mf_tag_alpha.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + {"epoch": 1}, + make_asset_bytes("alpha"), + ) + b = await asset_factory( + "mf_tag_beta.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + {"epoch": 2}, + make_asset_bytes("beta"), + ) + + params = { + "include_tags": "unit-tests,mf-tag,alpha", + "exclude_tags": "beta", + "name_contains": "mf_tag_", + "metadata_filter": json.dumps({"epoch": 1}), + } + async with http.get(api_base + "/api/assets", params=params) as r: + body = await r.json() + assert r.status == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + +@pytest.mark.asyncio +async def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): + # Three assets in same scope with different sizes and a common filter key + t = ["models", "checkpoints", "unit-tests", "mf-sort"] + n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" + await asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) + await asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) + await asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) + + # Sort by size ascending with paging + q = {"include_tags": "unit-tests,mf-sort", "metadata_filter": json.dumps({"group": "g"}), "sort": "size", "order": "asc", "limit": "2"} + async with http.get(api_base + "/api/assets", params=q) as r1: + b1 = await r1.json() + assert r1.status == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == [n1, n2] + assert b1["has_more"] is True + + q2 = {**q, "offset": "2"} + async with http.get(api_base + "/api/assets", params=q2) as r2: + b2 = await r2.json() + assert r2.status == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == [n3] + assert b2["has_more"] is False