feat: remove Asset when there are no references left + bugfixes + more tests

This commit is contained in:
bigcat88 2025-09-08 22:37:39 +03:00
parent 0e9de2b7c9
commit dfb5703d40
No known key found for this signature in database
GPG Key ID: 1F0BF0EC3CF22721
7 changed files with 332 additions and 51 deletions

View File

@ -277,9 +277,13 @@ async def upload_asset(request: web.Request) -> web.Response:
os.remove(tmp_path) os.remove(tmp_path)
msg = str(e) msg = str(e)
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
return _error_response(400, "HASH_MISMATCH", "Uploaded file hash does not match provided hash.") return _error_response(
400,
"HASH_MISMATCH",
"Uploaded file hash does not match provided hash.",
)
return _error_response(400, "BAD_REQUEST", "Invalid inputs.") return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
except Exception: except Exception as e:
if tmp_path and os.path.exists(tmp_path): if tmp_path and os.path.exists(tmp_path):
os.remove(tmp_path) os.remove(tmp_path)
return _error_response(500, "INTERNAL", "Unexpected server error.") return _error_response(500, "INTERNAL", "Unexpected server error.")
@ -343,10 +347,14 @@ async def delete_asset(request: web.Request) -> web.Response:
except Exception: except Exception:
return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.") return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.")
delete_content = request.query.get("delete_content")
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
try: try:
deleted = await assets_manager.delete_asset_reference( deleted = await assets_manager.delete_asset_reference(
asset_info_id=asset_info_id, asset_info_id=asset_info_id,
owner_id=UserManager.get_request_user_id(request), owner_id=UserManager.get_request_user_id(request),
delete_content_if_orphan=delete_content,
) )
except Exception: except Exception:
return _error_response(500, "INTERNAL", "Unexpected server error.") return _error_response(500, "INTERNAL", "Unexpected server error.")

View File

@ -27,6 +27,8 @@ from .database.services import (
create_asset_info_for_existing_asset, create_asset_info_for_existing_asset,
fetch_asset_info_asset_and_tags, fetch_asset_info_asset_and_tags,
get_asset_info_by_id, get_asset_info_by_id,
list_cache_states_by_asset_hash,
asset_info_exists_for_hash,
) )
from .api import schemas_in, schemas_out from .api import schemas_in, schemas_out
from ._assets_helpers import ( from ._assets_helpers import (
@ -371,11 +373,40 @@ async def update_asset(
) )
async def delete_asset_reference(*, asset_info_id: str, owner_id: str) -> bool: async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
"""Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default),
delete the Asset row as well and remove all cached files recorded for that asset_hash.
"""
async with await create_session() as session: async with await create_session() as session:
r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_hash = info_row.asset_hash if info_row else None
deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
await session.commit()
return False
if not delete_content_if_orphan or not asset_hash:
await session.commit()
return True
still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash)
if still_exists:
await session.commit()
return True
states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = await get_asset_by_hash(session, asset_hash=asset_hash)
if asset_row is not None:
await session.delete(asset_row)
await session.commit() await session.commit()
return r for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
async def create_asset_from_hash( async def create_asset_from_hash(

View File

@ -36,6 +36,17 @@ async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) ->
return await session.get(AssetInfo, asset_info_id) return await session.get(AssetInfo, asset_info_id)
async def asset_info_exists_for_hash(session: AsyncSession, *, asset_hash: str) -> bool:
return (
await session.execute(
sa.select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_hash == asset_hash)
.limit(1)
)
).first() is not None
async def check_fs_asset_exists_quick( async def check_fs_asset_exists_quick(
session, session,
*, *,
@ -586,7 +597,7 @@ async def create_asset_info_for_existing_asset(
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
) -> AssetInfo: ) -> AssetInfo:
"""Create a new AssetInfo referencing an existing Asset (no content write).""" """Create a new AssetInfo referencing an existing Asset. If row already exists, return it unchanged."""
now = utcnow() now = utcnow()
info = AssetInfo( info = AssetInfo(
owner_id=owner_id, owner_id=owner_id,
@ -597,8 +608,25 @@ async def create_asset_info_for_existing_asset(
updated_at=now, updated_at=now,
last_access_time=now, last_access_time=now,
) )
session.add(info) try:
await session.flush() # get info.id async with session.begin_nested():
session.add(info)
await session.flush() # get info.id
except IntegrityError:
existing = (
await session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_hash == asset_hash,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore
# if user_metadata is not None: # if user_metadata is not None:

View File

@ -52,7 +52,6 @@ def comfy_tmp_base_dir() -> Path:
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
_make_base_dirs(tmp) _make_base_dirs(tmp)
yield tmp yield tmp
# cleanup in a best-effort way; ComfyUI should not keep files open in this dir
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
for p in sorted(tmp.rglob("*"), reverse=True): for p in sorted(tmp.rglob("*"), reverse=True):
if p.is_file() or p.is_symlink(): if p.is_file() or p.is_symlink():
@ -72,10 +71,9 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path):
- autoscan disabled - autoscan disabled
Returns (base_url, process, port) Returns (base_url, process, port)
""" """
port = 8500 # _free_port() port = _free_port()
db_url = "sqlite+aiosqlite:///:memory:" db_url = "sqlite+aiosqlite:///:memory:"
# stdout/stderr capturing for debugging if something goes wrong
logs_dir = comfy_tmp_base_dir / "logs" logs_dir = comfy_tmp_base_dir / "logs"
logs_dir.mkdir(exist_ok=True) logs_dir.mkdir(exist_ok=True)
out_log = open(logs_dir / "stdout.log", "w", buffering=1) out_log = open(logs_dir / "stdout.log", "w", buffering=1)
@ -138,28 +136,59 @@ def api_base(comfy_url_and_proc) -> str:
return base_url return base_url
@pytest.fixture async def _post_multipart_asset(
def make_asset_bytes() -> Callable[[str], bytes]: session: aiohttp.ClientSession,
def _make(name: str) -> bytes: base: str,
# Generate deterministic small content variations based on name *,
seed = sum(ord(c) for c in name) % 251 name: str,
data = bytes((i * 31 + seed) % 256 for i in range(8192)) tags: list[str],
return data meta: dict,
return _make data: bytes,
extra_fields: dict | None = None,
) -> tuple[int, dict]:
async def _upload_asset(session: aiohttp.ClientSession, base: str, *, name: str, tags: list[str], meta: dict) -> dict:
make_asset_bytes = bytes((i % 251) for i in range(4096))
form = aiohttp.FormData() form = aiohttp.FormData()
form.add_field("file", make_asset_bytes, filename=name, content_type="application/octet-stream") form.add_field("file", data, filename=name, content_type="application/octet-stream")
form.add_field("tags", json.dumps(tags)) form.add_field("tags", json.dumps(tags))
form.add_field("name", name) form.add_field("name", name)
form.add_field("user_metadata", json.dumps(meta)) form.add_field("user_metadata", json.dumps(meta))
if extra_fields:
for k, v in extra_fields.items():
form.add_field(k, v)
async with session.post(base + "/api/assets", data=form) as r: async with session.post(base + "/api/assets", data=form) as r:
body = await r.json() body = await r.json()
assert r.status in (200, 201), body return r.status, body
@pytest.fixture
def make_asset_bytes() -> Callable[[str, int], bytes]:
def _make(name: str, size: int = 8192) -> bytes:
seed = sum(ord(c) for c in name) % 251
return bytes((i * 31 + seed) % 256 for i in range(size))
return _make
@pytest_asyncio.fixture
async def asset_factory(http: aiohttp.ClientSession, api_base: str):
"""
Returns create(name, tags, meta, data) -> response dict
Tracks created ids and deletes them after the test.
"""
created: list[str] = []
async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
assert status in (200, 201), body
created.append(body["id"])
return body return body
yield create
# cleanup by id
for aid in created:
with contextlib.suppress(Exception):
async with http.delete(f"{api_base}/api/assets/{aid}") as r:
await r.read()
@pytest_asyncio.fixture @pytest_asyncio.fixture
async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict: async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
@ -179,3 +208,25 @@ async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
body = await r.json() body = await r.json()
assert r.status == 201, body assert r.status == 201, body
return body return body
@pytest_asyncio.fixture(autouse=True)
async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
yield
while True:
async with http.get(
api_base + "/api/assets",
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
) as r:
body = await r.json()
if r.status != 200:
break
ids = [a["id"] for a in body.get("assets", [])]
if not ids:
break
for aid in ids:
with contextlib.suppress(Exception):
async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
await dr.read()

View File

@ -1,26 +0,0 @@
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_tags_listing_endpoint(http: aiohttp.ClientSession, api_base: str):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# Only used tags
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
# Should contain no tags
assert not [t["name"] for t in body2["tags"]]
# TODO-1: add some asset
# TODO-2: check that "used" tags are now non zero amount
# TODO-3: do a global teardown, so the state of ComfyUI is clear after each test, and all test can be run solo or one-by-one without any problems.

56
tests-assets/test_tags.py Normal file
View File

@ -0,0 +1,56 @@
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# Only used tags before we add anything new from this test cycle
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
# We already seeded one asset via fixture, so used tags must be non-empty
used_names = [t["name"] for t in body2["tags"]]
assert "models" in used_names
assert "checkpoints" in used_names
# Prefix filter should refine the list
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
b3 = await r3.json()
assert r3.status == 200
names3 = [t["name"] for t in b3["tags"]]
assert "unit-tests" in names3
assert "models" not in names3 # filtered out by prefix
# Order by name ascending should be stable
async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
b4 = await r4.json()
assert r4.status == 200
names4 = [t["name"] for t in b4["tags"]]
assert names4 == sorted(names4)
@pytest.mark.asyncio
async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str):
# Include zero-usage tags by default
async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
body1 = await r1.json()
assert r1.status == 200
names = [t["name"] for t in body1["tags"]]
# A few system tags from migration should exist:
assert "models" in names
assert "checkpoints" in names
# With include_zero=False there should be no tags returned for the database without Assets.
async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
body2 = await r2.json()
assert r2.status == 200
assert not [t["name"] for t in body2["tags"]]

View File

@ -0,0 +1,133 @@
import json
import aiohttp
import pytest
@pytest.mark.asyncio
async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
body = await r.json()
assert r.status == 415
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
@pytest.mark.asyncio
async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData(default_to_multipart=True)
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
form.add_field("name", "x.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "MISSING_FILE"
@pytest.mark.asyncio
async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
form.add_field("name", "m.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] == "INVALID_BODY"
assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
@pytest.mark.asyncio
async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
form = aiohttp.FormData()
form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
# '..' should be rejected by destination resolver
form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
form.add_field("name", "evil.safetensors")
async with http.post(api_base + "/api/assets", data=form) as r:
body = await r.json()
assert r.status == 400
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
@pytest.mark.asyncio
async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
name = "dup_a.safetensors"
tags = ["models", "checkpoints", "unit-tests", "alpha"]
meta = {"purpose": "dup"}
data = make_asset_bytes(name)
form1 = aiohttp.FormData()
form1.add_field("file", data, filename=name, content_type="application/octet-stream")
form1.add_field("tags", json.dumps(tags))
form1.add_field("name", name)
form1.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form1) as r1:
a1 = await r1.json()
assert r1.status == 201, a1
assert a1["created_new"] is True
# Second upload with the same data and name should return created_new == False and the same asset
form2 = aiohttp.FormData()
form2.add_field("file", data, filename=name, content_type="application/octet-stream")
form2.add_field("tags", json.dumps(tags))
form2.add_field("name", name)
form2.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form2) as r2:
a2 = await r2.json()
assert r2.status == 200, a2
assert a2["created_new"] is False
assert a2["asset_hash"] == a1["asset_hash"]
assert a2["id"] == a1["id"] # old reference
# Third upload with the same data but new name should return created_new == False and the new AssetReference
form3 = aiohttp.FormData()
form3.add_field("file", data, filename=name, content_type="application/octet-stream")
form3.add_field("tags", json.dumps(tags))
form3.add_field("name", name + "_d")
form3.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form3) as r2:
a3 = await r2.json()
assert r2.status == 200, a3
assert a3["created_new"] is False
assert a3["asset_hash"] == a1["asset_hash"]
assert a3["id"] != a1["id"]  # new reference created for the same asset content
@pytest.mark.asyncio
async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
# Seed a small file first
name = "fastpath_seed.safetensors"
tags = ["models", "checkpoints", "unit-tests"]
meta = {}
form1 = aiohttp.FormData()
form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
form1.add_field("tags", json.dumps(tags))
form1.add_field("name", name)
form1.add_field("user_metadata", json.dumps(meta))
async with http.post(api_base + "/api/assets", data=form1) as r1:
b1 = await r1.json()
assert r1.status == 201, b1
h = b1["asset_hash"]
# Now POST /api/assets with only hash and no file
form2 = aiohttp.FormData()
form2.add_field("hash", h)
form2.add_field("tags", json.dumps(tags))
form2.add_field("name", "fastpath_copy.safetensors")
form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
async with http.post(api_base + "/api/assets", data=form2) as r2:
b2 = await r2.json()
assert r2.status == 200, b2 # fast path returns 200 with created_new == False
assert b2["created_new"] is False
assert b2["asset_hash"] == h
@pytest.mark.asyncio
async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
payload = {
"hash": "blake3:" + "0" * 64,
"name": "nonexistent.bin",
"tags": ["models", "checkpoints", "unit-tests"],
}
async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
body = await r.json()
assert r.status == 404
assert body["error"]["code"] == "ASSET_NOT_FOUND"