diff --git a/app/api/assets_routes.py b/app/api/assets_routes.py
index 71e99f231..248f7a2f9 100644
--- a/app/api/assets_routes.py
+++ b/app/api/assets_routes.py
@@ -277,9 +277,13 @@ async def upload_asset(request: web.Request) -> web.Response:
             os.remove(tmp_path)
         msg = str(e)
         if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
-            return _error_response(400, "HASH_MISMATCH", "Uploaded file hash does not match provided hash.")
+            return _error_response(
+                400,
+                "HASH_MISMATCH",
+                "Uploaded file hash does not match provided hash.",
+            )
         return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
-    except Exception:
+    except Exception as e:
         if tmp_path and os.path.exists(tmp_path):
             os.remove(tmp_path)
         return _error_response(500, "INTERNAL", "Unexpected server error.")
@@ -343,10 +347,14 @@ async def delete_asset(request: web.Request) -> web.Response:
     except Exception:
         return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.")
 
+    delete_content = request.query.get("delete_content")
+    delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
+
     try:
         deleted = await assets_manager.delete_asset_reference(
             asset_info_id=asset_info_id,
             owner_id=UserManager.get_request_user_id(request),
+            delete_content_if_orphan=delete_content,
         )
     except Exception:
         return _error_response(500, "INTERNAL", "Unexpected server error.")
diff --git a/app/assets_manager.py b/app/assets_manager.py
index b84b61508..a2a73773a 100644
--- a/app/assets_manager.py
+++ b/app/assets_manager.py
@@ -27,6 +27,8 @@ from .database.services import (
     create_asset_info_for_existing_asset,
     fetch_asset_info_asset_and_tags,
     get_asset_info_by_id,
+    list_cache_states_by_asset_hash,
+    asset_info_exists_for_hash,
 )
 from .api import schemas_in, schemas_out
 from ._assets_helpers import (
@@ -371,11 +373,40 @@ async def update_asset(
     )
 
 
-async def delete_asset_reference(*, asset_info_id: str, owner_id: str) -> bool:
+async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
+    """Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default),
+    delete the Asset row as well and remove all cached files recorded for that asset_hash.
+    """
     async with await create_session() as session:
-        r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
+        info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
+        asset_hash = info_row.asset_hash if info_row else None
+        deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
+        if not deleted:
+            await session.commit()
+            return False
+
+        if not delete_content_if_orphan or not asset_hash:
+            await session.commit()
+            return True
+
+        still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash)
+        if still_exists:
+            await session.commit()
+            return True
+
+        states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash)
+        file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
+
+        asset_row = await get_asset_by_hash(session, asset_hash=asset_hash)
+        if asset_row is not None:
+            await session.delete(asset_row)
+
         await session.commit()
-    return r
+    for p in file_paths:
+        with contextlib.suppress(Exception):
+            if p and os.path.isfile(p):
+                os.remove(p)
+    return True
 
 
 async def create_asset_from_hash(
diff --git a/app/database/services.py b/app/database/services.py
index 42f647d91..842103e9e 100644
--- a/app/database/services.py
+++ b/app/database/services.py
@@ -36,6 +36,17 @@ async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) ->
     return await session.get(AssetInfo, asset_info_id)
 
 
+async def asset_info_exists_for_hash(session: AsyncSession, *, asset_hash: str) -> bool:
+    return (
+        await session.execute(
+            sa.select(sa.literal(True))
+            .select_from(AssetInfo)
+            .where(AssetInfo.asset_hash == asset_hash)
+            .limit(1)
+        )
+    ).first() is not None
+
+
 async def check_fs_asset_exists_quick(
     session,
     *,
@@ -586,7 +597,7 @@ async def create_asset_info_for_existing_asset(
     tag_origin: str = "manual",
     owner_id: str = "",
 ) -> AssetInfo:
-    """Create a new AssetInfo referencing an existing Asset (no content write)."""
+    """Create a new AssetInfo referencing an existing Asset. If row already exists, return it unchanged."""
     now = utcnow()
     info = AssetInfo(
         owner_id=owner_id,
@@ -597,8 +608,25 @@ async def create_asset_info_for_existing_asset(
         updated_at=now,
         last_access_time=now,
     )
-    session.add(info)
-    await session.flush()  # get info.id
+    try:
+        async with session.begin_nested():
+            session.add(info)
+            await session.flush()  # get info.id
+    except IntegrityError:
+        existing = (
+            await session.execute(
+                select(AssetInfo)
+                .where(
+                    AssetInfo.asset_hash == asset_hash,
+                    AssetInfo.name == name,
+                    AssetInfo.owner_id == owner_id,
+                )
+                .limit(1)
+            )
+        ).scalars().first()
+        if not existing:
+            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
+        return existing
 
     # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore
     # if user_metadata is not None:
diff --git a/tests-assets/conftest.py b/tests-assets/conftest.py
index 82d02dc74..24eed0728 100644
--- a/tests-assets/conftest.py
+++ b/tests-assets/conftest.py
@@ -52,7 +52,6 @@ def comfy_tmp_base_dir() -> Path:
     tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
     _make_base_dirs(tmp)
     yield tmp
-    # cleanup in a best-effort way; ComfyUI should not keep files open in this dir
     with contextlib.suppress(Exception):
         for p in sorted(tmp.rglob("*"), reverse=True):
             if p.is_file() or p.is_symlink():
@@ -72,10 +71,9 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path):
     - autoscan disabled
     Returns (base_url, process, port)
     """
-    port = 8500  # _free_port()
+    port = _free_port()
     db_url = "sqlite+aiosqlite:///:memory:"
 
-    # stdout/stderr capturing for debugging if something goes wrong
     logs_dir = comfy_tmp_base_dir / "logs"
     logs_dir.mkdir(exist_ok=True)
     out_log = open(logs_dir / "stdout.log", "w", buffering=1)
@@ -138,28 +136,59 @@ def api_base(comfy_url_and_proc) -> str:
     return base_url
 
 
-@pytest.fixture
-def make_asset_bytes() -> Callable[[str], bytes]:
-    def _make(name: str) -> bytes:
-        # Generate deterministic small content variations based on name
-        seed = sum(ord(c) for c in name) % 251
-        data = bytes((i * 31 + seed) % 256 for i in range(8192))
-        return data
-    return _make
-
-
-async def _upload_asset(session: aiohttp.ClientSession, base: str, *, name: str, tags: list[str], meta: dict) -> dict:
-    make_asset_bytes = bytes((i % 251) for i in range(4096))
+async def _post_multipart_asset(
+    session: aiohttp.ClientSession,
+    base: str,
+    *,
+    name: str,
+    tags: list[str],
+    meta: dict,
+    data: bytes,
+    extra_fields: dict | None = None,
+) -> tuple[int, dict]:
     form = aiohttp.FormData()
-    form.add_field("file", make_asset_bytes, filename=name, content_type="application/octet-stream")
+    form.add_field("file", data, filename=name, content_type="application/octet-stream")
     form.add_field("tags", json.dumps(tags))
     form.add_field("name", name)
     form.add_field("user_metadata", json.dumps(meta))
+    if extra_fields:
+        for k, v in extra_fields.items():
+            form.add_field(k, v)
     async with session.post(base + "/api/assets", data=form) as r:
         body = await r.json()
-        assert r.status in (200, 201), body
+        return r.status, body
+
+
+@pytest.fixture
+def make_asset_bytes() -> Callable[[str, int], bytes]:
+    def _make(name: str, size: int = 8192) -> bytes:
+        seed = sum(ord(c) for c in name) % 251
+        return bytes((i * 31 + seed) % 256 for i in range(size))
+    return _make
+
+
+@pytest_asyncio.fixture
+async def asset_factory(http: aiohttp.ClientSession, api_base: str):
+    """
+    Returns create(name, tags, meta, data) -> response dict
+    Tracks created ids and deletes them after the test.
+    """
+    created: list[str] = []
+
+    async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
+        status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
+        assert status in (200, 201), body
+        created.append(body["id"])
         return body
+    yield create
+
+    # cleanup by id
+    for aid in created:
+        with contextlib.suppress(Exception):
+            async with http.delete(f"{api_base}/api/assets/{aid}") as r:
+                await r.read()
+
 
 @pytest_asyncio.fixture
 async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
@@ -179,3 +208,25 @@ async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
         body = await r.json()
         assert r.status == 201, body
     return body
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
+    """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
+    yield
+
+    while True:
+        async with http.get(
+            api_base + "/api/assets",
+            params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
+        ) as r:
+            body = await r.json()
+        if r.status != 200:
+            break
+        ids = [a["id"] for a in body.get("assets", [])]
+        if not ids:
+            break
+        for aid in ids:
+            with contextlib.suppress(Exception):
+                async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
+                    await dr.read()
diff --git a/tests-assets/test_assets.py b/tests-assets/test_assets.py
deleted file mode 100644
index dfcedc52c..000000000
--- a/tests-assets/test_assets.py
+++ /dev/null
@@ -1,26 +0,0 @@
-import aiohttp
-import pytest
-
-
-@pytest.mark.asyncio
-async def test_tags_listing_endpoint(http: aiohttp.ClientSession, api_base: str):
-    # Include zero-usage tags by default
-    async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
-        body1 = await r1.json()
-        assert r1.status == 200
-        names = [t["name"] for t in body1["tags"]]
-        # A few system tags from migration should exist:
-        assert "models" in names
-        assert "checkpoints" in names
-
-    # Only used tags
-    async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
-        body2 = await r2.json()
-        assert r2.status == 200
-        # Should contain no tags
-        assert not [t["name"] for t in body2["tags"]]
-
-    # TODO-1: add some asset
-    # TODO-2: check that "used" tags are now non zero amount
-
-    # TODO-3: do a global teardown, so the state of ComfyUI is clear after each test, and all test can be run solo or one-by-one without any problems.
diff --git a/tests-assets/test_tags.py b/tests-assets/test_tags.py
new file mode 100644
index 000000000..c63df48bc
--- /dev/null
+++ b/tests-assets/test_tags.py
@@ -0,0 +1,56 @@
+import aiohttp
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
+    # Include zero-usage tags by default
+    async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
+        body1 = await r1.json()
+        assert r1.status == 200
+        names = [t["name"] for t in body1["tags"]]
+        # A few system tags from migration should exist:
+        assert "models" in names
+        assert "checkpoints" in names
+
+    # Only used tags before we add anything new from this test cycle
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
+        body2 = await r2.json()
+        assert r2.status == 200
+        # We already seeded one asset via fixture, so used tags must be non-empty
+        used_names = [t["name"] for t in body2["tags"]]
+        assert "models" in used_names
+        assert "checkpoints" in used_names
+
+    # Prefix filter should refine the list
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
+        b3 = await r3.json()
+        assert r3.status == 200
+        names3 = [t["name"] for t in b3["tags"]]
+        assert "unit-tests" in names3
+        assert "models" not in names3  # filtered out by prefix
+
+    # Order by name ascending should be stable
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
+        b4 = await r4.json()
+        assert r4.status == 200
+        names4 = [t["name"] for t in b4["tags"]]
+        assert names4 == sorted(names4)
+
+
+@pytest.mark.asyncio
+async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str):
+    # Include zero-usage tags by default
+    async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
+        body1 = await r1.json()
+        assert r1.status == 200
+        names = [t["name"] for t in body1["tags"]]
+        # A few system tags from migration should exist:
+        assert "models" in names
+        assert "checkpoints" in names
+
+    # With include_zero=False there should be no tags returned for the database without Assets.
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
+        body2 = await r2.json()
+        assert r2.status == 200
+        assert not [t["name"] for t in body2["tags"]]
diff --git a/tests-assets/test_uploads.py b/tests-assets/test_uploads.py
new file mode 100644
index 000000000..65c34f139
--- /dev/null
+++ b/tests-assets/test_uploads.py
@@ -0,0 +1,133 @@
+import json
+import aiohttp
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
+    async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
+        body = await r.json()
+        assert r.status == 415
+        assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
+
+
+@pytest.mark.asyncio
+async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
+    form = aiohttp.FormData(default_to_multipart=True)
+    form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
+    form.add_field("name", "x.safetensors")
+    async with http.post(api_base + "/api/assets", data=form) as r:
+        body = await r.json()
+        assert r.status == 400
+        assert body["error"]["code"] == "MISSING_FILE"
+
+
+@pytest.mark.asyncio
+async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
+    form = aiohttp.FormData()
+    form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
+    form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
+    form.add_field("name", "m.safetensors")
+    async with http.post(api_base + "/api/assets", data=form) as r:
+        body = await r.json()
+        assert r.status == 400
+        assert body["error"]["code"] == "INVALID_BODY"
+        assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
+
+
+@pytest.mark.asyncio
+async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
+    form = aiohttp.FormData()
+    form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
+    # '..' should be rejected by destination resolver
+    form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
+    form.add_field("name", "evil.safetensors")
+    async with http.post(api_base + "/api/assets", data=form) as r:
+        body = await r.json()
+        assert r.status == 400
+        assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
+
+
+@pytest.mark.asyncio
+async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
+    name = "dup_a.safetensors"
+    tags = ["models", "checkpoints", "unit-tests", "alpha"]
+    meta = {"purpose": "dup"}
+    data = make_asset_bytes(name)
+    form1 = aiohttp.FormData()
+    form1.add_field("file", data, filename=name, content_type="application/octet-stream")
+    form1.add_field("tags", json.dumps(tags))
+    form1.add_field("name", name)
+    form1.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form1) as r1:
+        a1 = await r1.json()
+        assert r1.status == 201, a1
+        assert a1["created_new"] is True
+
+    # Second upload with the same data and name should return created_new == False and the same asset
+    form2 = aiohttp.FormData()
+    form2.add_field("file", data, filename=name, content_type="application/octet-stream")
+    form2.add_field("tags", json.dumps(tags))
+    form2.add_field("name", name)
+    form2.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form2) as r2:
+        a2 = await r2.json()
+        assert r2.status == 200, a2
+        assert a2["created_new"] is False
+        assert a2["asset_hash"] == a1["asset_hash"]
+        assert a2["id"] == a1["id"]  # old reference
+
+    # Third upload with the same data but new name should return created_new == False and the new AssetReference
+    form3 = aiohttp.FormData()
+    form3.add_field("file", data, filename=name, content_type="application/octet-stream")
+    form3.add_field("tags", json.dumps(tags))
+    form3.add_field("name", name + "_d")
+    form3.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form3) as r2:
+        a3 = await r2.json()
+        assert r2.status == 200, a3
+        assert a3["created_new"] is False
+        assert a3["asset_hash"] == a1["asset_hash"]
+        assert a3["id"] != a1["id"]  # old reference
+
+
+@pytest.mark.asyncio
+async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
+    # Seed a small file first
+    name = "fastpath_seed.safetensors"
+    tags = ["models", "checkpoints", "unit-tests"]
+    meta = {}
+    form1 = aiohttp.FormData()
+    form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
+    form1.add_field("tags", json.dumps(tags))
+    form1.add_field("name", name)
+    form1.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form1) as r1:
+        b1 = await r1.json()
+        assert r1.status == 201, b1
+        h = b1["asset_hash"]
+
+    # Now POST /api/assets with only hash and no file
+    form2 = aiohttp.FormData()
+    form2.add_field("hash", h)
+    form2.add_field("tags", json.dumps(tags))
+    form2.add_field("name", "fastpath_copy.safetensors")
+    form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
+    async with http.post(api_base + "/api/assets", data=form2) as r2:
+        b2 = await r2.json()
+        assert r2.status == 200, b2  # fast path returns 200 with created_new == False
+        assert b2["created_new"] is False
+        assert b2["asset_hash"] == h
+
+
+@pytest.mark.asyncio
+async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
+    payload = {
+        "hash": "blake3:" + "0" * 64,
+        "name": "nonexistent.bin",
+        "tags": ["models", "checkpoints", "unit-tests"],
+    }
+    async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
+        body = await r.json()
+        assert r.status == 404
+        assert body["error"]["code"] == "ASSET_NOT_FOUND"