mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-21 03:50:50 +08:00
fixed validation error + more tests
This commit is contained in:
parent
faa1e4de17
commit
0ef73e95fd
@ -485,4 +485,4 @@ def _error_response(status: int, code: str, message: str, details: Optional[dict
|
|||||||
|
|
||||||
|
|
||||||
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
||||||
return _error_response(400, code, "Validation failed.", {"errors": ve.errors()})
|
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||||
|
|||||||
85
tests-assets/test_list_filter.py
Normal file
85
tests-assets/test_list_filter.py
Normal file
@ -0,0 +1,85 @@
|
|||||||
|
import json
|
||||||
|
import aiohttp
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_assets_paging_and_sort(http: aiohttp.ClientSession, api_base: str, asset_factory, make_asset_bytes):
|
||||||
|
names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
|
||||||
|
for n in names:
|
||||||
|
await asset_factory(
|
||||||
|
n,
|
||||||
|
["models", "checkpoints", "unit-tests", "paging"],
|
||||||
|
{"epoch": 1},
|
||||||
|
make_asset_bytes(n, size=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
# name ascending for stable order
|
||||||
|
async with http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
|
||||||
|
) as r1:
|
||||||
|
b1 = await r1.json()
|
||||||
|
assert r1.status == 200
|
||||||
|
got1 = [a["name"] for a in b1["assets"]]
|
||||||
|
assert got1 == sorted(names)[:2]
|
||||||
|
assert b1["has_more"] is True
|
||||||
|
|
||||||
|
async with http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
|
||||||
|
) as r2:
|
||||||
|
b2 = await r2.json()
|
||||||
|
assert r2.status == 200
|
||||||
|
got2 = [a["name"] for a in b2["assets"]]
|
||||||
|
assert got2 == sorted(names)[2:]
|
||||||
|
assert b2["has_more"] is False
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_assets_include_exclude_and_name_contains(http: aiohttp.ClientSession, api_base: str, asset_factory):
|
||||||
|
a = await asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||||
|
b = await asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||||
|
|
||||||
|
async with http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
|
||||||
|
) as r:
|
||||||
|
body = await r.json()
|
||||||
|
assert r.status == 200
|
||||||
|
names = [x["name"] for x in body["assets"]]
|
||||||
|
assert a["name"] in names
|
||||||
|
assert b["name"] not in names
|
||||||
|
|
||||||
|
async with http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests", "name_contains": "inc_"},
|
||||||
|
) as r2:
|
||||||
|
body2 = await r2.json()
|
||||||
|
assert r2.status == 200
|
||||||
|
names2 = [x["name"] for x in body2["assets"]]
|
||||||
|
assert a["name"] in names2
|
||||||
|
assert b["name"] in names2
|
||||||
|
|
||||||
|
async with http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "non-existing-tag"},
|
||||||
|
) as r2:
|
||||||
|
body3 = await r2.json()
|
||||||
|
assert r2.status == 200
|
||||||
|
assert not body3["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_list_assets_invalid_query_rejected(http: aiohttp.ClientSession, api_base: str):
|
||||||
|
# limit too small
|
||||||
|
async with http.get(api_base + "/api/assets", params={"limit": "0"}) as r1:
|
||||||
|
b1 = await r1.json()
|
||||||
|
assert r1.status == 400
|
||||||
|
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
# bad metadata JSON
|
||||||
|
async with http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}) as r2:
|
||||||
|
b2 = await r2.json()
|
||||||
|
assert r2.status == 400
|
||||||
|
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||||
@ -54,3 +54,47 @@ async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str):
|
|||||||
body2 = await r2.json()
|
body2 = await r2.json()
|
||||||
assert r2.status == 200
|
assert r2.status == 200
|
||||||
assert not [t["name"] for t in body2["tags"]]
|
assert not [t["name"] for t in body2["tags"]]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_add_and_remove_tags(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
|
||||||
|
# Add tags with duplicates and mixed case
|
||||||
|
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||||
|
async with http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add) as r1:
|
||||||
|
b1 = await r1.json()
|
||||||
|
assert r1.status == 200, b1
|
||||||
|
# normalized and deduplicated
|
||||||
|
assert "newtag" in b1["added"] or "beta" in b1["added"] or "unit-tests" not in b1["added"]
|
||||||
|
|
||||||
|
async with http.get(f"{api_base}/api/assets/{aid}") as rg:
|
||||||
|
g = await rg.json()
|
||||||
|
assert rg.status == 200
|
||||||
|
tags_now = set(g["tags"])
|
||||||
|
assert "newtag" in tags_now
|
||||||
|
assert "beta" in tags_now
|
||||||
|
|
||||||
|
# Remove a tag and a non-existent tag
|
||||||
|
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||||
|
async with http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del) as r2:
|
||||||
|
b2 = await r2.json()
|
||||||
|
assert r2.status == 200
|
||||||
|
assert "newtag" in b2["removed"]
|
||||||
|
assert "does-not-exist" in b2["not_present"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tags_list_order_and_prefix(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
|
||||||
|
# name ascending
|
||||||
|
async with http.get(api_base + "/api/tags", params={"order": "name_asc", "limit": "100"}) as r1:
|
||||||
|
b1 = await r1.json()
|
||||||
|
assert r1.status == 200
|
||||||
|
names = [t["name"] for t in b1["tags"]]
|
||||||
|
assert names == sorted(names)
|
||||||
|
|
||||||
|
# invalid limit rejected
|
||||||
|
async with http.get(api_base + "/api/tags", params={"limit": "1001"}) as r2:
|
||||||
|
b2 = await r2.json()
|
||||||
|
assert r2.status == 400
|
||||||
|
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user