mirror of https://github.com/comfyanonymous/ComfyUI.git
synced 2025-12-18 18:43:05 +08:00
concurrency upload test + fixed 2 related bugs
This commit is contained in:
parent
4a713654cd
commit
975650060f
@@ -1,6 +1,8 @@
 from typing import Iterable
 
 from sqlalchemy import delete, select
+from sqlalchemy.dialects import postgresql as d_pg
+from sqlalchemy.dialects import sqlite as d_sqlite
 from sqlalchemy.ext.asyncio import AsyncSession
 
 from ..._assets_helpers import normalize_tags
@@ -13,13 +15,29 @@ async def ensure_tags_exist(session: AsyncSession, names: Iterable[str], tag_type
     if not wanted:
         return []
     existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
+    existing_names = {t.name for t in existing}
+    missing = [n for n in wanted if n not in existing_names]
+    if missing:
+        dialect = session.bind.dialect.name
+        rows = [{"name": n, "tag_type": tag_type} for n in missing]
+        if dialect == "sqlite":
+            ins = (
+                d_sqlite.insert(Tag)
+                .values(rows)
+                .on_conflict_do_nothing(index_elements=[Tag.name])
+            )
+        elif dialect == "postgresql":
+            ins = (
+                d_pg.insert(Tag)
+                .values(rows)
+                .on_conflict_do_nothing(index_elements=[Tag.name])
+            )
+        else:
+            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+        await session.execute(ins)
+        existing = (await session.execute(select(Tag).where(Tag.name.in_(wanted)))).scalars().all()
     by_name = {t.name: t for t in existing}
-    to_create = [Tag(name=n, tag_type=tag_type) for n in wanted if n not in by_name]
-    if to_create:
-        session.add_all(to_create)
-        await session.flush()
-        by_name.update({t.name: t for t in to_create})
-    return [by_name[n] for n in wanted]
+    return [by_name[n] for n in wanted if n in by_name]
 
 
 async def add_missing_tag_for_asset_id(
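The removed select-then-add path was the first of the two bugs: a check-then-act race in which two concurrent requests could both see a tag as missing, both try to INSERT it, and one would fail on the unique constraint. The replacement routes tag creation through the dialect's native INSERT ... ON CONFLICT DO NOTHING and then re-selects. A minimal self-contained sketch of that pattern against an in-memory SQLite database (the Tag model and names below are illustrative, not the project's real schema):

from sqlalchemy import Column, String, create_engine, select
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Tag(Base):
    __tablename__ = "tags"
    name = Column(String, primary_key=True)  # the unique key the upsert targets

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Running the same insert twice stands in for two racing writers:
    # ON CONFLICT DO NOTHING turns the duplicate insert into a no-op
    # instead of an IntegrityError.
    for _ in range(2):
        session.execute(
            sqlite_insert(Tag)
            .values([{"name": "models"}, {"name": "unit-tests"}])
            .on_conflict_do_nothing(index_elements=[Tag.name])
        )
    # Re-select afterwards: the rows may have been created by this writer or
    # by a concurrent one, so a fresh SELECT is the only reliable way to get
    # ORM objects for every requested name.
    tags = session.execute(select(Tag)).scalars().all()
    print(sorted(t.name for t in tags))  # ['models', 'unit-tests']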
@@ -484,6 +484,7 @@ async def ingest_fs_asset(
     """
     locator = os.path.abspath(abs_path)
     now = utcnow()
+    dialect = session.bind.dialect.name
 
     if preview_id:
         if not await session.get(Asset, preview_id):
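Hoisting dialect = session.bind.dialect.name to the top of ingest_fs_asset gives both upsert sites below a single switch to branch on (and lets a later hunk drop a duplicate lookup). A quick illustrative check of what that attribute returns, assuming a plain SQLite engine:

from sqlalchemy import create_engine
from sqlalchemy.orm import Session

with Session(create_engine("sqlite://")) as session:
    print(session.bind.dialect.name)  # "sqlite"; "postgresql" for a Postgres engine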
@@ -502,10 +503,34 @@ async def ingest_fs_asset(
         await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
     ).scalars().first()
     if not asset:
-        async with session.begin_nested():
-            asset = Asset(hash=asset_hash, size_bytes=int(size_bytes), mime_type=mime_type, created_at=now)
-            session.add(asset)
-            await session.flush()
+        vals = {
+            "hash": asset_hash,
+            "size_bytes": int(size_bytes),
+            "mime_type": mime_type,
+            "created_at": now,
+        }
+        if dialect == "sqlite":
+            ins = (
+                d_sqlite.insert(Asset)
+                .values(**vals)
+                .on_conflict_do_nothing(index_elements=[Asset.hash])
+            )
+        elif dialect == "postgresql":
+            ins = (
+                d_pg.insert(Asset)
+                .values(**vals)
+                .on_conflict_do_nothing(index_elements=[Asset.hash])
+            )
+        else:
+            raise NotImplementedError(f"Unsupported database dialect: {dialect}")
+        res = await session.execute(ins)
+        rowcount = int(res.rowcount or 0)
+        asset = (
+            await session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
+        ).scalars().first()
+        if not asset:
+            raise RuntimeError("Asset row not found after upsert.")
+        if rowcount > 0:
             out["asset_created"] = True
     else:
         changed = False
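The rowcount bookkeeping above is the second fix's core: with ON CONFLICT DO NOTHING, the INSERT affects one row for the writer that actually created the asset and zero rows for the writer that hit the existing row, so exactly one of two racing uploads reports it as created, and the follow-up SELECT fetches the row regardless of who inserted it. A standalone sketch of that behavior (model and values are illustrative):

from sqlalchemy import Column, String, create_engine
from sqlalchemy.dialects.sqlite import insert as sqlite_insert
from sqlalchemy.orm import Session, declarative_base

Base = declarative_base()

class Asset(Base):
    __tablename__ = "assets"
    hash = Column(String, primary_key=True)

engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

ins = (
    sqlite_insert(Asset)
    .values(hash="abc123")
    .on_conflict_do_nothing(index_elements=[Asset.hash])
)
with Session(engine) as session:
    print(session.execute(ins).rowcount)  # 1 -> this writer created the row
    print(session.execute(ins).rowcount)  # 0 -> row already existed; not "created"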
@@ -524,7 +549,6 @@ async def ingest_fs_asset(
            "file_path": locator,
            "mtime_ns": int(mtime_ns),
        }
-       dialect = session.bind.dialect.name
        if dialect == "sqlite":
            ins = (
                d_sqlite.insert(AssetCacheState)
@@ -1,4 +1,7 @@
+import asyncio
 import json
+import uuid
+
 import aiohttp
 import pytest
 
@@ -125,6 +128,54 @@ async def test_upload_multiple_tags_fields_are_merged(http: aiohttp.ClientSession
     assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
 
 
+@pytest.mark.asyncio
+@pytest.mark.parametrize("root", ["input", "output"])
+async def test_concurrent_upload_identical_bytes_different_names(
+    root: str,
+    http: aiohttp.ClientSession,
+    api_base: str,
+    make_asset_bytes,
+):
+    """
+    Two concurrent uploads of identical bytes but different names.
+    Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
+    """
+    scope = f"concupload-{uuid.uuid4().hex[:6]}"
+    name1, name2 = "cu_a.bin", "cu_b.bin"
+    data = make_asset_bytes("concurrent", 4096)
+    tags = [root, "unit-tests", scope]
+
+    def _form(name: str) -> aiohttp.FormData:
+        f = aiohttp.FormData()
+        f.add_field("file", data, filename=name, content_type="application/octet-stream")
+        f.add_field("tags", json.dumps(tags))
+        f.add_field("name", name)
+        f.add_field("user_metadata", json.dumps({}))
+        return f
+
+    r1, r2 = await asyncio.gather(
+        http.post(api_base + "/api/assets", data=_form(name1)),
+        http.post(api_base + "/api/assets", data=_form(name2)),
+    )
+    b1, b2 = await r1.json(), await r2.json()
+    assert r1.status in (200, 201), b1
+    assert r2.status in (200, 201), b2
+    assert b1["asset_hash"] == b2["asset_hash"]
+    assert b1["id"] != b2["id"]
+
+    created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
+    assert created_flags == [False, True]
+
+    async with http.get(
+        api_base + "/api/assets",
+        params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
+    ) as rl:
+        bl = await rl.json()
+        assert rl.status == 200, bl
+        names = [a["name"] for a in bl.get("assets", [])]
+        assert set([name1, name2]).issubset(names)
+
+
 @pytest.mark.asyncio
 async def test_create_from_hash_endpoint_404(http: aiohttp.ClientSession, api_base: str):
     payload = {
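The new test relies on the suite's http, api_base, and make_asset_bytes fixtures, none of which appear in this diff. A hypothetical minimal make_asset_bytes consistent with how the test calls it, producing deterministic bytes from a seed label and a size (this implementation is an assumption, not code from this commit):

import hashlib

import pytest

@pytest.fixture
def make_asset_bytes():
    # Hypothetical stand-in: deterministic pseudo-random bytes derived from a
    # seed string, built from repeated SHA-256 blocks and truncated to size.
    def _make(seed: str, size: int) -> bytes:
        out = bytearray()
        counter = 0
        while len(out) < size:
            out += hashlib.sha256(f"{seed}:{counter}".encode()).digest()
            counter += 1
        return bytes(out[:size])
    return _make

With the fixtures in place, the test can be run on its own, e.g. pytest -k test_concurrent_upload_identical_bytes_different_names.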