add more list_assets tests + fix one found bug

This commit is contained in:
bigcat88 2025-09-09 20:35:18 +03:00
parent 1886f10e19
commit 964de8a8ad
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
2 changed files with 229 additions and 2 deletions

View File

@ -34,7 +34,16 @@ async def head_asset_by_hash(request: web.Request) -> web.Response:
@ROUTES.get("/api/assets")
async def list_assets(request: web.Request) -> web.Response:
query_dict = dict(request.rel_url.query)
qp = request.rel_url.query
query_dict = {}
if "include_tags" in qp:
query_dict["include_tags"] = qp.getall("include_tags")
if "exclude_tags" in qp:
query_dict["exclude_tags"] = qp.getall("exclude_tags")
for k in ("name_contains", "metadata_filter", "limit", "offset", "sort", "order"):
v = qp.get(k)
if v is not None:
query_dict[k] = v
try:
q = schemas_in.ListAssetsQuery.model_validate(query_dict)

View File

@ -1,4 +1,5 @@
import json
import asyncio
import aiohttp
import pytest
@ -70,6 +71,223 @@ async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.Clien
assert not body3["assets"]
@pytest.mark.asyncio
async def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-size"]
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
await asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
await asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
await asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
) as r1:
b1 = await r1.json()
names = [a["name"] for a in b1["assets"]]
assert names[:3] == [n1, n2, n3]
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
) as r2:
b2 = await r2.json()
names2 = [a["name"] for a in b2["assets"]]
assert names2[:3] == [n3, n2, n1]
@pytest.mark.asyncio
async def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
a1 = await asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
a2 = await asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
# Rename the second asset to bump updated_at
async with http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}) as rp:
upd = await rp.json()
assert rp.status == 200, upd
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
) as r:
body = await r.json()
assert r.status == 200
names = [x["name"] for x in body["assets"]]
assert names[0] == "upd_b_renamed.safetensors"
assert a1["name"] in names
@pytest.mark.asyncio
async def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-access"]
await asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
await asyncio.sleep(0.02)
a2 = await asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
# Touch last_access_time of b by downloading its content
await asyncio.sleep(0.02)
async with http.get(f"{api_base}/api/assets/{a2['id']}/content") as dl:
assert dl.status == 200
await dl.read()
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
) as r:
body = await r.json()
assert r.status == 200
names = [x["name"] for x in body["assets"]]
assert names[0] == a2["name"]
@pytest.mark.asyncio
async def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-include"]
a = await asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
await asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
# CSV + case-insensitive
async with http.get(
api_base + "/api/assets",
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [x["name"] for x in b1["assets"]]
assert a["name"] in names1
assert not any("beta" in x for x in names1)
# Repeated query params for include_tags
params_multi = [
("include_tags", "unit-tests"),
("include_tags", "lf-include"),
("include_tags", "alpha"),
]
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
b2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in b2["assets"]]
assert a["name"] in names2
assert not any("beta" in x for x in names2)
# Duplicates and spaces in CSV
async with http.get(
api_base + "/api/assets",
params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [x["name"] for x in b3["assets"]]
assert a["name"] in names3
@pytest.mark.asyncio
async def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
a = await asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
await asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
# Exclude uppercase should work
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [x["name"] for x in b1["assets"]]
assert a["name"] in names1
# Repeated excludes with duplicates
params_multi = [
("include_tags", "unit-tests"),
("include_tags", "lf-exclude"),
("exclude_tags", "beta"),
("exclude_tags", "beta"),
]
async with http.get(api_base + "/api/assets", params=params_multi) as r2:
b2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in b2["assets"]]
assert all("beta" not in x for x in names2)
@pytest.mark.asyncio
async def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-name"]
a1 = await asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
a2 = await asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
names1 = [x["name"] for x in b1["assets"]]
assert a1["name"] in names1
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
) as r2:
b2 = await r2.json()
assert r2.status == 200
names2 = [x["name"] for x in b2["assets"]]
assert a1["name"] in names2
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [x["name"] for x in b3["assets"]]
assert a2["name"] in names3
@pytest.mark.asyncio
async def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
await asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
await asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
await asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
# Offset far beyond total
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
) as r1:
b1 = await r1.json()
assert r1.status == 200
assert not b1["assets"]
assert b1["has_more"] is False
# Boundary large limit (<=500 is valid)
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
) as r2:
b2 = await r2.json()
assert r2.status == 200
assert len(b2["assets"]) == 3
assert b2["has_more"] is False
@pytest.mark.asyncio
async def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
async with http.get(api_base + "/api/assets", params={"offset": "-1"}) as r1:
b1 = await r1.json()
assert r1.status == 400
assert b1["error"]["code"] == "INVALID_QUERY"
async with http.get(api_base + "/api/assets", params={"limit": "abc"}) as r2:
b2 = await r2.json()
assert r2.status == 400
assert b2["error"]["code"] == "INVALID_QUERY"
@pytest.mark.asyncio
async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str):
# limit too small