fix+test: escape "_" symbol in assets filtering

This commit is contained in:
bigcat88 2025-09-15 19:19:47 +03:00
parent f3cf99d10c
commit f1fb7432a0
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
2 changed files with 38 additions and 2 deletions

View File

@ -47,7 +47,8 @@ async def list_asset_infos_page(
) )
if name_contains: if name_contains:
base = base.where(AssetInfo.name.ilike(f"%{name_contains}%")) escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
@ -73,7 +74,8 @@ async def list_asset_infos_page(
.where(visible_owner_clause(owner_id)) .where(visible_owner_clause(owner_id))
) )
if name_contains: if name_contains:
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{name_contains}%")) escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)

View File

@ -1,4 +1,5 @@
import asyncio import asyncio
import uuid
import aiohttp import aiohttp
import pytest import pytest
@ -301,3 +302,36 @@ async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, a
b2 = await r2.json() b2 = await r2.json()
assert r2.status == 400 assert r2.status == 400
assert b2["error"]["code"] == "INVALID_QUERY" assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_list_assets_name_contains_literal_underscore(
http,
api_base,
asset_factory,
make_asset_bytes,
):
"""'name_contains' must treat '_' literally, not as a SQL wildcard.
We create:
- foo_bar.safetensors (should match)
- fooxbar.safetensors (must NOT match if '_' is escaped)
- foobar.safetensors (must NOT match)
"""
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
tags = ["models", "checkpoints", "unit-tests", scope]
a = await asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
b = await asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
c = await asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
async with http.get(
api_base + "/api/assets",
params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
) as r:
body = await r.json()
assert r.status == 200, body
names = [x["name"] for x in body["assets"]]
assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
assert body["total"] == 1