feat: remove Asset when there are no references left + bugfixes + more tests
parent 0e9de2b7c9
commit dfb5703d40
@@ -277,9 +277,13 @@ async def upload_asset(request: web.Request) -> web.Response:
             os.remove(tmp_path)
         msg = str(e)
         if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
-            return _error_response(400, "HASH_MISMATCH", "Uploaded file hash does not match provided hash.")
+            return _error_response(
+                400,
+                "HASH_MISMATCH",
+                "Uploaded file hash does not match provided hash.",
+            )
         return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
-    except Exception:
+    except Exception as e:
         if tmp_path and os.path.exists(tmp_path):
             os.remove(tmp_path)
         return _error_response(500, "INTERNAL", "Unexpected server error.")
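For orientation: the tests added later in this commit read failures as body["error"]["code"] and body["error"]["message"], so _error_response evidently wraps its arguments in that envelope. A minimal sketch of such a helper, assuming aiohttp's json_response (the real definition is outside this diff):

    from aiohttp import web

    def _error_response(status: int, code: str, message: str) -> web.Response:
        # Hypothetical reconstruction matching the {"error": {"code", "message"}}
        # shape asserted by the tests below; not necessarily the repo's helper.
        return web.json_response({"error": {"code": code, "message": message}}, status=status)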
@@ -343,10 +347,14 @@ async def delete_asset(request: web.Request) -> web.Response:
     except Exception:
         return _error_response(400, "INVALID_ID", f"AssetInfo id '{asset_info_id_raw}' is not a valid UUID.")

+    delete_content = request.query.get("delete_content")
+    delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
+
     try:
         deleted = await assets_manager.delete_asset_reference(
             asset_info_id=asset_info_id,
             owner_id=UserManager.get_request_user_id(request),
+            delete_content_if_orphan=delete_content,
         )
     except Exception:
         return _error_response(500, "INTERNAL", "Unexpected server error.")
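The new delete_content query parameter defaults to true; only "0", "false", or "no" (case-insensitive) disable orphan cleanup. A client-side sketch, assuming the DELETE /api/assets/{id} route that the tests below exercise:

    import aiohttp

    async def delete_reference_only(session: aiohttp.ClientSession, base: str, asset_info_id: str) -> int:
        # Keep the underlying Asset and cached files even if this was the last reference.
        async with session.delete(f"{base}/api/assets/{asset_info_id}", params={"delete_content": "false"}) as r:
            await r.read()
            return r.status  # the success status code is not shown in this hunk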
@@ -27,6 +27,8 @@ from .database.services import (
     create_asset_info_for_existing_asset,
     fetch_asset_info_asset_and_tags,
+    get_asset_info_by_id,
     list_cache_states_by_asset_hash,
+    asset_info_exists_for_hash,
 )
 from .api import schemas_in, schemas_out
 from ._assets_helpers import (
@@ -371,11 +373,40 @@ async def update_asset(
     )


-async def delete_asset_reference(*, asset_info_id: str, owner_id: str) -> bool:
+async def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
+    """Delete single AssetInfo. If this was the last reference to Asset and delete_content_if_orphan=True (default),
+    delete the Asset row as well and remove all cached files recorded for that asset_hash.
+    """
     async with await create_session() as session:
-        r = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
+        info_row = await get_asset_info_by_id(session, asset_info_id=asset_info_id)
+        asset_hash = info_row.asset_hash if info_row else None
+        deleted = await delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
+        if not deleted:
+            await session.commit()
+            return False
+
+        if not delete_content_if_orphan or not asset_hash:
+            await session.commit()
+            return True
+
+        still_exists = await asset_info_exists_for_hash(session, asset_hash=asset_hash)
+        if still_exists:
+            await session.commit()
+            return True
+
+        states = await list_cache_states_by_asset_hash(session, asset_hash=asset_hash)
+        file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
+
+        asset_row = await get_asset_by_hash(session, asset_hash=asset_hash)
+        if asset_row is not None:
+            await session.delete(asset_row)
+
         await session.commit()
-        return r
+    for p in file_paths:
+        with contextlib.suppress(Exception):
+            if p and os.path.isfile(p):
+                os.remove(p)
+    return True


 async def create_asset_from_hash(
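Note the ordering in the new function: the transaction commits before any os.remove, so a failed unlink can at worst leave a stray cached file, never a dangling database row. A hedged usage sketch (the UUID is hypothetical):

    async def drop_reference_keep_content() -> bool:
        # Remove one AssetInfo but keep the Asset row and cached files,
        # even if this was the last reference.
        return await delete_asset_reference(
            asset_info_id="123e4567-e89b-12d3-a456-426614174000",  # hypothetical id
            owner_id="",
            delete_content_if_orphan=False,
        )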
@@ -36,6 +36,17 @@ async def get_asset_info_by_id(session: AsyncSession, *, asset_info_id: str) ->
     return await session.get(AssetInfo, asset_info_id)


+async def asset_info_exists_for_hash(session: AsyncSession, *, asset_hash: str) -> bool:
+    return (
+        await session.execute(
+            sa.select(sa.literal(True))
+            .select_from(AssetInfo)
+            .where(AssetInfo.asset_hash == asset_hash)
+            .limit(1)
+        )
+    ).first() is not None
+
+
 async def check_fs_asset_exists_quick(
     session,
     *,
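The helper selects a literal True with LIMIT 1 instead of counting rows, letting the database stop at the first match. For comparison, the same check via sa.exists(), as a sketch assuming the same AssetInfo model and async session:

    import sqlalchemy as sa

    async def asset_info_exists_for_hash_alt(session, *, asset_hash: str) -> bool:
        # EXISTS subquery; most planners execute this the same way as SELECT ... LIMIT 1.
        result = await session.execute(
            sa.select(sa.exists().where(AssetInfo.asset_hash == asset_hash))
        )
        return bool(result.scalar())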
@@ -586,7 +597,7 @@ async def create_asset_info_for_existing_asset(
     tag_origin: str = "manual",
     owner_id: str = "",
 ) -> AssetInfo:
-    """Create a new AssetInfo referencing an existing Asset (no content write)."""
+    """Create a new AssetInfo referencing an existing Asset. If row already exists, return it unchanged."""
     now = utcnow()
     info = AssetInfo(
         owner_id=owner_id,
@@ -597,8 +608,25 @@
         updated_at=now,
         last_access_time=now,
     )
-    session.add(info)
-    await session.flush()  # get info.id
+    try:
+        async with session.begin_nested():
+            session.add(info)
+            await session.flush()  # get info.id
+    except IntegrityError:
+        existing = (
+            await session.execute(
+                select(AssetInfo)
+                .where(
+                    AssetInfo.asset_hash == asset_hash,
+                    AssetInfo.name == name,
+                    AssetInfo.owner_id == owner_id,
+                )
+                .limit(1)
+            )
+        ).scalars().first()
+        if not existing:
+            raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
+        return existing

     # Uncomment next code, and remove code after it, once the hack with "metadata[filename" is not needed anymore
     # if user_metadata is not None:
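session.begin_nested() opens a SAVEPOINT, so an IntegrityError from a duplicate (asset_hash, name, owner_id) rolls back only the nested block while the outer transaction stays usable for the fallback SELECT. The generic shape of this insert-or-fetch pattern, as a sketch with illustrative names:

    from sqlalchemy import select
    from sqlalchemy.exc import IntegrityError

    async def insert_or_fetch(session, row, fetch_stmt):
        try:
            # Optimistic INSERT inside a SAVEPOINT.
            async with session.begin_nested():
                session.add(row)
                await session.flush()
            return row
        except IntegrityError:
            # On a unique-constraint conflict, return whichever row won the race.
            return (await session.execute(fetch_stmt)).scalars().first()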
@@ -52,7 +52,6 @@ def comfy_tmp_base_dir() -> Path:
     tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
     _make_base_dirs(tmp)
     yield tmp
-    # cleanup in a best-effort way; ComfyUI should not keep files open in this dir
     with contextlib.suppress(Exception):
         for p in sorted(tmp.rglob("*"), reverse=True):
             if p.is_file() or p.is_symlink():
@@ -72,10 +71,9 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path):
     - autoscan disabled
     Returns (base_url, process, port)
     """
-    port = 8500  # _free_port()
+    port = _free_port()
     db_url = "sqlite+aiosqlite:///:memory:"
-
     # stdout/stderr capturing for debugging if something goes wrong
     logs_dir = comfy_tmp_base_dir / "logs"
     logs_dir.mkdir(exist_ok=True)
     out_log = open(logs_dir / "stdout.log", "w", buffering=1)
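_free_port() itself is not shown in this diff; a common implementation of such a helper (an assumption, not necessarily the repo's code) binds port 0 so the OS picks a free one:

    import socket

    def _free_port() -> int:
        with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
            s.bind(("127.0.0.1", 0))  # port 0 asks the OS for any unused port
            return s.getsockname()[1]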
@@ -138,28 +136,59 @@ def api_base(comfy_url_and_proc) -> str:
     return base_url


-@pytest.fixture
-def make_asset_bytes() -> Callable[[str], bytes]:
-    def _make(name: str) -> bytes:
-        # Generate deterministic small content variations based on name
-        seed = sum(ord(c) for c in name) % 251
-        data = bytes((i * 31 + seed) % 256 for i in range(8192))
-        return data
-    return _make
-
-
-async def _upload_asset(session: aiohttp.ClientSession, base: str, *, name: str, tags: list[str], meta: dict) -> dict:
-    make_asset_bytes = bytes((i % 251) for i in range(4096))
+async def _post_multipart_asset(
+    session: aiohttp.ClientSession,
+    base: str,
+    *,
+    name: str,
+    tags: list[str],
+    meta: dict,
+    data: bytes,
+    extra_fields: dict | None = None,
+) -> tuple[int, dict]:
     form = aiohttp.FormData()
-    form.add_field("file", make_asset_bytes, filename=name, content_type="application/octet-stream")
+    form.add_field("file", data, filename=name, content_type="application/octet-stream")
     form.add_field("tags", json.dumps(tags))
     form.add_field("name", name)
     form.add_field("user_metadata", json.dumps(meta))
+    if extra_fields:
+        for k, v in extra_fields.items():
+            form.add_field(k, v)
     async with session.post(base + "/api/assets", data=form) as r:
         body = await r.json()
-        assert r.status in (200, 201), body
+        return r.status, body

+
+@pytest.fixture
+def make_asset_bytes() -> Callable[[str, int], bytes]:
+    def _make(name: str, size: int = 8192) -> bytes:
+        seed = sum(ord(c) for c in name) % 251
+        return bytes((i * 31 + seed) % 256 for i in range(size))
+    return _make
+
+
+@pytest_asyncio.fixture
+async def asset_factory(http: aiohttp.ClientSession, api_base: str):
+    """
+    Returns create(name, tags, meta, data) -> response dict
+    Tracks created ids and deletes them after the test.
+    """
+    created: list[str] = []
+
+    async def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
+        status, body = await _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
+        assert status in (200, 201), body
+        created.append(body["id"])
+        return body
+
+    yield create
+
+    # cleanup by id
+    for aid in created:
+        with contextlib.suppress(Exception):
+            async with http.delete(f"{api_base}/api/assets/{aid}") as r:
+                await r.read()


 @pytest_asyncio.fixture
 async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
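How these fixtures are meant to compose in a test, as an illustrative sketch (not part of the commit; fixture names are the ones defined above):

    import pytest

    @pytest.mark.asyncio
    async def test_factory_roundtrip(asset_factory, make_asset_bytes):
        data = make_asset_bytes("demo.safetensors", 1024)  # deterministic 1 KiB payload
        body = await asset_factory("demo.safetensors", ["models", "checkpoints", "unit-tests"], {}, data)
        assert body["created_new"] is True
        # Teardown: asset_factory deletes the created id, and the autouse
        # autoclean fixture below sweeps anything still tagged 'unit-tests'.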
@@ -179,3 +208,25 @@ async def seeded_asset(http: aiohttp.ClientSession, api_base: str) -> dict:
         body = await r.json()
         assert r.status == 201, body
         return body
+
+
+@pytest_asyncio.fixture(autouse=True)
+async def autoclean_unit_test_assets(http: aiohttp.ClientSession, api_base: str):
+    """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
+    yield
+
+    while True:
+        async with http.get(
+            api_base + "/api/assets",
+            params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
+        ) as r:
+            body = await r.json()
+        if r.status != 200:
+            break
+        ids = [a["id"] for a in body.get("assets", [])]
+        if not ids:
+            break
+        for aid in ids:
+            with contextlib.suppress(Exception):
+                async with http.delete(f"{api_base}/api/assets/{aid}") as dr:
+                    await dr.read()
@@ -1,26 +0,0 @@
-import aiohttp
-import pytest
-
-
-@pytest.mark.asyncio
-async def test_tags_listing_endpoint(http: aiohttp.ClientSession, api_base: str):
-    # Include zero-usage tags by default
-    async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
-        body1 = await r1.json()
-        assert r1.status == 200
-    names = [t["name"] for t in body1["tags"]]
-    # A few system tags from migration should exist:
-    assert "models" in names
-    assert "checkpoints" in names
-
-    # Only used tags
-    async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
-        body2 = await r2.json()
-        assert r2.status == 200
-    # Should contain no tags
-    assert not [t["name"] for t in body2["tags"]]
-
-    # TODO-1: add some asset
-    # TODO-2: check that "used" tags are now non zero amount
-
-    # TODO-3: do a global teardown, so the state of ComfyUI is clear after each test, and all test can be run solo or one-by-one without any problems.
tests-assets/test_tags.py (new file, 56 lines)
@@ -0,0 +1,56 @@
+import aiohttp
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_tags_present(http: aiohttp.ClientSession, api_base: str, seeded_asset: dict):
+    # Include zero-usage tags by default
+    async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
+        body1 = await r1.json()
+        assert r1.status == 200
+    names = [t["name"] for t in body1["tags"]]
+    # A few system tags from migration should exist:
+    assert "models" in names
+    assert "checkpoints" in names
+
+    # Only used tags before we add anything new from this test cycle
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
+        body2 = await r2.json()
+        assert r2.status == 200
+    # We already seeded one asset via fixture, so used tags must be non-empty
+    used_names = [t["name"] for t in body2["tags"]]
+    assert "models" in used_names
+    assert "checkpoints" in used_names
+
+    # Prefix filter should refine the list
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}) as r3:
+        b3 = await r3.json()
+        assert r3.status == 200
+    names3 = [t["name"] for t in b3["tags"]]
+    assert "unit-tests" in names3
+    assert "models" not in names3  # filtered out by prefix
+
+    # Order by name ascending should be stable
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}) as r4:
+        b4 = await r4.json()
+        assert r4.status == 200
+    names4 = [t["name"] for t in b4["tags"]]
+    assert names4 == sorted(names4)
+
+
+@pytest.mark.asyncio
+async def test_tags_empty_usage(http: aiohttp.ClientSession, api_base: str):
+    # Include zero-usage tags by default
+    async with http.get(api_base + "/api/tags", params={"limit": "50"}) as r1:
+        body1 = await r1.json()
+        assert r1.status == 200
+    names = [t["name"] for t in body1["tags"]]
+    # A few system tags from migration should exist:
+    assert "models" in names
+    assert "checkpoints" in names
+
+    # With include_zero=False there should be no tags returned for the database without Assets.
+    async with http.get(api_base + "/api/tags", params={"include_zero": "false"}) as r2:
+        body2 = await r2.json()
+        assert r2.status == 200
+    assert not [t["name"] for t in body2["tags"]]
tests-assets/test_uploads.py (new file, 133 lines)
@@ -0,0 +1,133 @@
+import json
+import aiohttp
+import pytest
+
+
+@pytest.mark.asyncio
+async def test_upload_requires_multipart(http: aiohttp.ClientSession, api_base: str):
+    async with http.post(api_base + "/api/assets", json={"foo": "bar"}) as r:
+        body = await r.json()
+        assert r.status == 415
+    assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
+
+
+@pytest.mark.asyncio
+async def test_upload_missing_file_and_hash(http: aiohttp.ClientSession, api_base: str):
+    form = aiohttp.FormData(default_to_multipart=True)
+    form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests"]))
+    form.add_field("name", "x.safetensors")
+    async with http.post(api_base + "/api/assets", data=form) as r:
+        body = await r.json()
+        assert r.status == 400
+    assert body["error"]["code"] == "MISSING_FILE"
+
+
+@pytest.mark.asyncio
+async def test_upload_models_unknown_category(http: aiohttp.ClientSession, api_base: str):
+    form = aiohttp.FormData()
+    form.add_field("file", b"A" * 128, filename="m.safetensors", content_type="application/octet-stream")
+    form.add_field("tags", json.dumps(["models", "no_such_category", "unit-tests"]))
+    form.add_field("name", "m.safetensors")
+    async with http.post(api_base + "/api/assets", data=form) as r:
+        body = await r.json()
+        assert r.status == 400
+    assert body["error"]["code"] == "INVALID_BODY"
+    assert "unknown models category" in body["error"]["message"] or "unknown model category" in body["error"]["message"]
+
+
+@pytest.mark.asyncio
+async def test_upload_tags_traversal_guard(http: aiohttp.ClientSession, api_base: str):
+    form = aiohttp.FormData()
+    form.add_field("file", b"A" * 256, filename="evil.safetensors", content_type="application/octet-stream")
+    # '..' should be rejected by destination resolver
+    form.add_field("tags", json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]))
+    form.add_field("name", "evil.safetensors")
+    async with http.post(api_base + "/api/assets", data=form) as r:
+        body = await r.json()
+        assert r.status == 400
+    assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
+
+
+@pytest.mark.asyncio
+async def test_upload_ok_duplicate_reference(http: aiohttp.ClientSession, api_base: str, make_asset_bytes):
+    name = "dup_a.safetensors"
+    tags = ["models", "checkpoints", "unit-tests", "alpha"]
+    meta = {"purpose": "dup"}
+    data = make_asset_bytes(name)
+    form1 = aiohttp.FormData()
+    form1.add_field("file", data, filename=name, content_type="application/octet-stream")
+    form1.add_field("tags", json.dumps(tags))
+    form1.add_field("name", name)
+    form1.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form1) as r1:
+        a1 = await r1.json()
+        assert r1.status == 201, a1
+    assert a1["created_new"] is True
+
+    # Second upload with the same data and name should return created_new == False and the same asset
+    form2 = aiohttp.FormData()
+    form2.add_field("file", data, filename=name, content_type="application/octet-stream")
+    form2.add_field("tags", json.dumps(tags))
+    form2.add_field("name", name)
+    form2.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form2) as r2:
+        a2 = await r2.json()
+        assert r2.status == 200, a2
+    assert a2["created_new"] is False
+    assert a2["asset_hash"] == a1["asset_hash"]
+    assert a2["id"] == a1["id"]  # same (old) reference
+
+    # Third upload with the same data but a new name should return created_new == False and a new AssetInfo reference
+    form3 = aiohttp.FormData()
+    form3.add_field("file", data, filename=name, content_type="application/octet-stream")
+    form3.add_field("tags", json.dumps(tags))
+    form3.add_field("name", name + "_d")
+    form3.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form3) as r2:
+        a3 = await r2.json()
+        assert r2.status == 200, a3
+    assert a3["created_new"] is False
+    assert a3["asset_hash"] == a1["asset_hash"]
+    assert a3["id"] != a1["id"]  # new reference
+
+
+@pytest.mark.asyncio
+async def test_upload_fastpath_from_existing_hash_no_file(http: aiohttp.ClientSession, api_base: str):
+    # Seed a small file first
+    name = "fastpath_seed.safetensors"
+    tags = ["models", "checkpoints", "unit-tests"]
+    meta = {}
+    form1 = aiohttp.FormData()
+    form1.add_field("file", b"B" * 1024, filename=name, content_type="application/octet-stream")
+    form1.add_field("tags", json.dumps(tags))
+    form1.add_field("name", name)
+    form1.add_field("user_metadata", json.dumps(meta))
+    async with http.post(api_base + "/api/assets", data=form1) as r1:
+        b1 = await r1.json()
+        assert r1.status == 201, b1
+    h = b1["asset_hash"]
+
+    # Now POST /api/assets with only hash and no file
+    form2 = aiohttp.FormData()
+    form2.add_field("hash", h)
+    form2.add_field("tags", json.dumps(tags))
+    form2.add_field("name", "fastpath_copy.safetensors")
+    form2.add_field("user_metadata", json.dumps({"purpose": "copy"}))
+    async with http.post(api_base + "/api/assets", data=form2) as r2:
+        b2 = await r2.json()
+        assert r2.status == 200, b2  # fast path returns 200 with created_new == False
+    assert b2["created_new"] is False
+    assert b2["asset_hash"] == h
+
+
+@pytest.mark.asyncio
+async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
+    payload = {
+        "hash": "blake3:" + "0" * 64,
+        "name": "nonexistent.bin",
+        "tags": ["models", "checkpoints", "unit-tests"],
+    }
+    async with http.post(api_base + "/api/assets/from-hash", json=payload) as r:
+        body = await r.json()
+        assert r.status == 404
+    assert body["error"]["code"] == "ASSET_NOT_FOUND"
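The from-hash payload above uses the blake3:<64 hex chars> identifier format. A sketch of computing a matching identifier client-side, assuming the blake3 PyPI package (the server-side hasher is not part of this diff):

    import blake3

    def asset_hash_for(data: bytes) -> str:
        # A 32-byte blake3 digest renders as 64 hex chars, matching "blake3:" + "0" * 64 above.
        return "blake3:" + blake3.blake3(data).hexdigest()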