mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-02-06 03:22:33 +08:00
Assets Part 2 - add more endpoints (#12125)
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Execution Tests / test (macos-latest) (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.10, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.11, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-stable (12.1, , linux, 3.12, [self-hosted Linux], stable) (push) Waiting to run
Full Comfy CI Workflow Runs / test-unix-nightly (12.1, , linux, 3.11, [self-hosted Linux], nightly) (push) Waiting to run
Execution Tests / test (ubuntu-latest) (push) Waiting to run
Execution Tests / test (windows-latest) (push) Waiting to run
Test server launches without errors / test (push) Waiting to run
Unit Tests / test (macos-latest) (push) Waiting to run
Unit Tests / test (ubuntu-latest) (push) Waiting to run
Unit Tests / test (windows-2022) (push) Waiting to run
This commit is contained in:
parent
6e469a3f35
commit
6ea8c128a3
@ -1,5 +1,8 @@
|
|||||||
import logging
|
import logging
|
||||||
import uuid
|
import uuid
|
||||||
|
import urllib.parse
|
||||||
|
import os
|
||||||
|
import contextlib
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
|
|
||||||
from pydantic import ValidationError
|
from pydantic import ValidationError
|
||||||
@ -8,6 +11,9 @@ import app.assets.manager as manager
|
|||||||
from app import user_manager
|
from app import user_manager
|
||||||
from app.assets.api import schemas_in
|
from app.assets.api import schemas_in
|
||||||
from app.assets.helpers import get_query_dict
|
from app.assets.helpers import get_query_dict
|
||||||
|
from app.assets.scanner import seed_assets
|
||||||
|
|
||||||
|
import folder_paths
|
||||||
|
|
||||||
ROUTES = web.RouteTableDef()
|
ROUTES = web.RouteTableDef()
|
||||||
USER_MANAGER: user_manager.UserManager | None = None
|
USER_MANAGER: user_manager.UserManager | None = None
|
||||||
@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None
|
|||||||
# UUID regex (canonical hyphenated form, case-insensitive)
|
# UUID regex (canonical hyphenated form, case-insensitive)
|
||||||
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}"
|
||||||
|
|
||||||
|
# Note to any custom node developers reading this code:
|
||||||
|
# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same.
|
||||||
|
|
||||||
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None:
|
||||||
global USER_MANAGER
|
global USER_MANAGER
|
||||||
USER_MANAGER = user_manager_instance
|
USER_MANAGER = user_manager_instance
|
||||||
@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response:
|
|||||||
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
return _error_response(400, code, "Validation failed.", {"errors": ve.json()})
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.head("/api/assets/hash/{hash}")
|
||||||
|
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||||
|
hash_str = request.match_info.get("hash", "").strip().lower()
|
||||||
|
if not hash_str or ":" not in hash_str:
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
algo, digest = hash_str.split(":", 1)
|
||||||
|
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
exists = manager.asset_exists(asset_hash=hash_str)
|
||||||
|
return web.Response(status=200 if exists else 404)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get("/api/assets")
|
@ROUTES.get("/api/assets")
|
||||||
async def list_assets(request: web.Request) -> web.Response:
|
async def list_assets(request: web.Request) -> web.Response:
|
||||||
"""
|
"""
|
||||||
@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response:
|
|||||||
order=q.order,
|
order=q.order,
|
||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(payload.model_dump(mode="json"))
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response:
|
|||||||
return web.json_response(result.model_dump(mode="json"), status=200)
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||||
|
async def download_asset_content(request: web.Request) -> web.Response:
|
||||||
|
# question: do we need disposition? could we just stick with one of these?
|
||||||
|
disposition = request.query.get("disposition", "attachment").lower().strip()
|
||||||
|
if disposition not in {"inline", "attachment"}:
|
||||||
|
disposition = "attachment"
|
||||||
|
|
||||||
|
try:
|
||||||
|
abs_path, content_type, filename = manager.resolve_asset_content_for_download(
|
||||||
|
asset_info_id=str(uuid.UUID(request.match_info["id"])),
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except ValueError as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve))
|
||||||
|
except NotImplementedError as nie:
|
||||||
|
return _error_response(501, "BACKEND_UNSUPPORTED", str(nie))
|
||||||
|
except FileNotFoundError:
|
||||||
|
return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.")
|
||||||
|
|
||||||
|
quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'")
|
||||||
|
cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}'
|
||||||
|
|
||||||
|
file_size = os.path.getsize(abs_path)
|
||||||
|
logging.info(
|
||||||
|
"download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s",
|
||||||
|
abs_path,
|
||||||
|
file_size,
|
||||||
|
file_size / (1024 * 1024),
|
||||||
|
content_type,
|
||||||
|
filename,
|
||||||
|
)
|
||||||
|
|
||||||
|
async def file_sender():
|
||||||
|
chunk_size = 64 * 1024
|
||||||
|
with open(abs_path, "rb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = f.read(chunk_size)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
return web.Response(
|
||||||
|
body=file_sender(),
|
||||||
|
content_type=content_type,
|
||||||
|
headers={
|
||||||
|
"Content-Disposition": cd,
|
||||||
|
"Content-Length": str(file_size),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets/from-hash")
|
||||||
|
async def create_asset_from_hash(request: web.Request) -> web.Response:
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
body = schemas_in.CreateFromHashBody.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
result = manager.create_asset_from_hash(
|
||||||
|
hash_str=body.hash,
|
||||||
|
name=body.name,
|
||||||
|
tags=body.tags,
|
||||||
|
user_metadata=body.user_metadata,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
if result is None:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist")
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=201)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets")
|
||||||
|
async def upload_asset(request: web.Request) -> web.Response:
|
||||||
|
"""Multipart/form-data endpoint for Asset uploads."""
|
||||||
|
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||||
|
return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.")
|
||||||
|
|
||||||
|
reader = await request.multipart()
|
||||||
|
|
||||||
|
file_present = False
|
||||||
|
file_client_name: str | None = None
|
||||||
|
tags_raw: list[str] = []
|
||||||
|
provided_name: str | None = None
|
||||||
|
user_metadata_raw: str | None = None
|
||||||
|
provided_hash: str | None = None
|
||||||
|
provided_hash_exists: bool | None = None
|
||||||
|
|
||||||
|
file_written = 0
|
||||||
|
tmp_path: str | None = None
|
||||||
|
while True:
|
||||||
|
field = await reader.next()
|
||||||
|
if field is None:
|
||||||
|
break
|
||||||
|
|
||||||
|
fname = getattr(field, "name", "") or ""
|
||||||
|
|
||||||
|
if fname == "hash":
|
||||||
|
try:
|
||||||
|
s = ((await field.text()) or "").strip().lower()
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
|
||||||
|
if s:
|
||||||
|
if ":" not in s:
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
algo, digest = s.split(":", 1)
|
||||||
|
if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||||
|
provided_hash = f"{algo}:{digest}"
|
||||||
|
try:
|
||||||
|
provided_hash_exists = manager.asset_exists(asset_hash=provided_hash)
|
||||||
|
except Exception:
|
||||||
|
provided_hash_exists = None # do not fail the whole request here
|
||||||
|
|
||||||
|
elif fname == "file":
|
||||||
|
file_present = True
|
||||||
|
file_client_name = (field.filename or "").strip()
|
||||||
|
|
||||||
|
if provided_hash and provided_hash_exists is True:
|
||||||
|
# If client supplied a hash that we know exists, drain but do not write to disk
|
||||||
|
try:
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
file_written += len(chunk)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.")
|
||||||
|
continue # Do not create temp file; we will create AssetInfo from the existing content
|
||||||
|
|
||||||
|
# Otherwise, store to temp for hashing/ingest
|
||||||
|
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||||
|
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||||
|
os.makedirs(unique_dir, exist_ok=True)
|
||||||
|
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||||
|
|
||||||
|
try:
|
||||||
|
with open(tmp_path, "wb") as f:
|
||||||
|
while True:
|
||||||
|
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||||
|
if not chunk:
|
||||||
|
break
|
||||||
|
f.write(chunk)
|
||||||
|
file_written += len(chunk)
|
||||||
|
except Exception:
|
||||||
|
try:
|
||||||
|
if os.path.exists(tmp_path or ""):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.")
|
||||||
|
elif fname == "tags":
|
||||||
|
tags_raw.append((await field.text()) or "")
|
||||||
|
elif fname == "name":
|
||||||
|
provided_name = (await field.text()) or None
|
||||||
|
elif fname == "user_metadata":
|
||||||
|
user_metadata_raw = (await field.text()) or None
|
||||||
|
|
||||||
|
# If client did not send file, and we are not doing a from-hash fast path -> error
|
||||||
|
if not file_present and not (provided_hash and provided_hash_exists):
|
||||||
|
return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.")
|
||||||
|
|
||||||
|
if file_present and file_written == 0 and not (provided_hash and provided_hash_exists):
|
||||||
|
# Empty upload is only acceptable if we are fast-pathing from existing hash
|
||||||
|
try:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
spec = schemas_in.UploadAssetSpec.model_validate({
|
||||||
|
"tags": tags_raw,
|
||||||
|
"name": provided_name,
|
||||||
|
"user_metadata": user_metadata_raw,
|
||||||
|
"hash": provided_hash,
|
||||||
|
})
|
||||||
|
except ValidationError as ve:
|
||||||
|
try:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
finally:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
|
||||||
|
# Validate models category against configured folders (consistent with previous behavior)
|
||||||
|
if spec.tags and spec.tags[0] == "models":
|
||||||
|
if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
return _error_response(
|
||||||
|
400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'"
|
||||||
|
)
|
||||||
|
|
||||||
|
owner_id = USER_MANAGER.get_request_user_id(request)
|
||||||
|
|
||||||
|
# Fast path: if a valid provided hash exists, create AssetInfo without writing anything
|
||||||
|
if spec.hash and provided_hash_exists is True:
|
||||||
|
try:
|
||||||
|
result = manager.create_asset_from_hash(
|
||||||
|
hash_str=spec.hash,
|
||||||
|
name=spec.name or (spec.hash.split(":", 1)[1]),
|
||||||
|
tags=spec.tags,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
if result is None:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist")
|
||||||
|
|
||||||
|
# Drain temp if we accidentally saved (e.g., hash field came after file)
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
|
||||||
|
status = 200 if (not result.created_new) else 201
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=status)
|
||||||
|
|
||||||
|
# Otherwise, we must have a temp file path to ingest
|
||||||
|
if not tmp_path or not os.path.exists(tmp_path):
|
||||||
|
# The only case we reach here without a temp file is: client sent a hash that does not exist and no file
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
created = manager.upload_asset_from_temp_path(
|
||||||
|
spec,
|
||||||
|
temp_path=tmp_path,
|
||||||
|
client_filename=file_client_name,
|
||||||
|
owner_id=owner_id,
|
||||||
|
expected_asset_hash=spec.hash,
|
||||||
|
)
|
||||||
|
status = 201 if created.created_new else 200
|
||||||
|
return web.json_response(created.model_dump(mode="json"), status=status)
|
||||||
|
except ValueError as e:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
msg = str(e)
|
||||||
|
if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH":
|
||||||
|
return _error_response(
|
||||||
|
400,
|
||||||
|
"HASH_MISMATCH",
|
||||||
|
"Uploaded file hash does not match provided hash.",
|
||||||
|
)
|
||||||
|
return _error_response(400, "BAD_REQUEST", "Invalid inputs.")
|
||||||
|
except Exception:
|
||||||
|
if tmp_path and os.path.exists(tmp_path):
|
||||||
|
os.remove(tmp_path)
|
||||||
|
logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
|
async def update_asset(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
try:
|
||||||
|
body = schemas_in.UpdateAssetBody.model_validate(await request.json())
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _validation_error_response("INVALID_BODY", ve)
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = manager.update_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
name=body.name,
|
||||||
|
user_metadata=body.user_metadata,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except (ValueError, PermissionError) as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"update_asset failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
|
async def delete_asset(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
delete_content = request.query.get("delete_content")
|
||||||
|
delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"}
|
||||||
|
|
||||||
|
try:
|
||||||
|
deleted = manager.delete_asset_reference(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
delete_content_if_orphan=delete_content,
|
||||||
|
)
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"delete_asset_reference failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
if not deleted:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.")
|
||||||
|
return web.Response(status=204)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get("/api/tags")
|
@ROUTES.get("/api/tags")
|
||||||
async def get_tags(request: web.Request) -> web.Response:
|
async def get_tags(request: web.Request) -> web.Response:
|
||||||
"""
|
"""
|
||||||
@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response:
|
|||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return web.json_response(result.model_dump(mode="json"))
|
return web.json_response(result.model_dump(mode="json"))
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
|
async def add_asset_tags(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
data = schemas_in.TagsAdd.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = manager.add_tags_to_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=data.tags,
|
||||||
|
origin="manual",
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except (ValueError, PermissionError) as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"add_tags_to_asset failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
|
async def delete_asset_tags(request: web.Request) -> web.Response:
|
||||||
|
asset_info_id = str(uuid.UUID(request.match_info["id"]))
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
data = schemas_in.TagsRemove.model_validate(payload)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()})
|
||||||
|
except Exception:
|
||||||
|
return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
result = manager.remove_tags_from_asset(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=data.tags,
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
except ValueError as ve:
|
||||||
|
return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id})
|
||||||
|
except Exception:
|
||||||
|
logging.exception(
|
||||||
|
"remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s",
|
||||||
|
asset_info_id,
|
||||||
|
USER_MANAGER.get_request_user_id(request),
|
||||||
|
)
|
||||||
|
return _error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
|
return web.json_response(result.model_dump(mode="json"), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.post("/api/assets/seed")
|
||||||
|
async def seed_assets_endpoint(request: web.Request) -> web.Response:
|
||||||
|
"""Trigger asset seeding for specified roots (models, input, output)."""
|
||||||
|
try:
|
||||||
|
payload = await request.json()
|
||||||
|
roots = payload.get("roots", ["models", "input", "output"])
|
||||||
|
except Exception:
|
||||||
|
roots = ["models", "input", "output"]
|
||||||
|
|
||||||
|
valid_roots = [r for r in roots if r in ("models", "input", "output")]
|
||||||
|
if not valid_roots:
|
||||||
|
return _error_response(400, "INVALID_BODY", "No valid roots specified")
|
||||||
|
|
||||||
|
try:
|
||||||
|
seed_assets(tuple(valid_roots))
|
||||||
|
except Exception:
|
||||||
|
logging.exception("seed_assets failed for roots=%s", valid_roots)
|
||||||
|
return _error_response(500, "INTERNAL", "Seed operation failed")
|
||||||
|
|
||||||
|
return web.json_response({"seeded": valid_roots}, status=200)
|
||||||
|
|||||||
@ -1,5 +1,4 @@
|
|||||||
import json
|
import json
|
||||||
import uuid
|
|
||||||
from typing import Any, Literal
|
from typing import Any, Literal
|
||||||
|
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
@ -8,9 +7,9 @@ from pydantic import (
|
|||||||
Field,
|
Field,
|
||||||
conint,
|
conint,
|
||||||
field_validator,
|
field_validator,
|
||||||
|
model_validator,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
class ListAssetsQuery(BaseModel):
|
class ListAssetsQuery(BaseModel):
|
||||||
include_tags: list[str] = Field(default_factory=list)
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
exclude_tags: list[str] = Field(default_factory=list)
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
class UpdateAssetBody(BaseModel):
|
||||||
|
name: str | None = None
|
||||||
|
user_metadata: dict[str, Any] | None = None
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _at_least_one(self):
|
||||||
|
if self.name is None and self.user_metadata is None:
|
||||||
|
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
class CreateFromHashBody(BaseModel):
|
||||||
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
hash: str
|
||||||
|
name: str
|
||||||
|
tags: list[str] = Field(default_factory=list)
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
|
||||||
|
@field_validator("hash")
|
||||||
|
@classmethod
|
||||||
|
def _require_blake3(cls, v):
|
||||||
|
s = (v or "").strip().lower()
|
||||||
|
if ":" not in s:
|
||||||
|
raise ValueError("hash must be 'blake3:<hex>'")
|
||||||
|
algo, digest = s.split(":", 1)
|
||||||
|
if algo != "blake3":
|
||||||
|
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||||
|
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
raise ValueError("hash digest must be lowercase hex")
|
||||||
|
return s
|
||||||
|
|
||||||
|
@field_validator("tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _tags_norm(cls, v):
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, list):
|
||||||
|
out = [str(t).strip().lower() for t in v if str(t).strip()]
|
||||||
|
seen = set()
|
||||||
|
dedup = []
|
||||||
|
for t in out:
|
||||||
|
if t not in seen:
|
||||||
|
seen.add(t)
|
||||||
|
dedup.append(t)
|
||||||
|
return dedup
|
||||||
|
if isinstance(v, str):
|
||||||
|
return [t.strip().lower() for t in v.split(",") if t.strip()]
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
class TagsListQuery(BaseModel):
|
class TagsListQuery(BaseModel):
|
||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
@ -75,20 +125,140 @@ class TagsListQuery(BaseModel):
|
|||||||
return v.lower() or None
|
return v.lower() or None
|
||||||
|
|
||||||
|
|
||||||
class SetPreviewBody(BaseModel):
|
class TagsAdd(BaseModel):
|
||||||
"""Set or clear the preview for an AssetInfo. Provide an Asset.id or null."""
|
model_config = ConfigDict(extra="ignore")
|
||||||
preview_id: str | None = None
|
tags: list[str] = Field(..., min_length=1)
|
||||||
|
|
||||||
@field_validator("preview_id", mode="before")
|
@field_validator("tags")
|
||||||
@classmethod
|
@classmethod
|
||||||
def _norm_uuid(cls, v):
|
def normalize_tags(cls, v: list[str]) -> list[str]:
|
||||||
|
out = []
|
||||||
|
for t in v:
|
||||||
|
if not isinstance(t, str):
|
||||||
|
raise TypeError("tags must be strings")
|
||||||
|
tnorm = t.strip().lower()
|
||||||
|
if tnorm:
|
||||||
|
out.append(tnorm)
|
||||||
|
seen = set()
|
||||||
|
deduplicated = []
|
||||||
|
for x in out:
|
||||||
|
if x not in seen:
|
||||||
|
seen.add(x)
|
||||||
|
deduplicated.append(x)
|
||||||
|
return deduplicated
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRemove(TagsAdd):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class UploadAssetSpec(BaseModel):
|
||||||
|
"""Upload Asset operation.
|
||||||
|
- tags: ordered; first is root ('models'|'input'|'output');
|
||||||
|
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||||
|
- name: display name
|
||||||
|
- user_metadata: arbitrary JSON object (optional)
|
||||||
|
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||||
|
|
||||||
|
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||||
|
and the original extension is preserved when available.
|
||||||
|
"""
|
||||||
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
|
tags: list[str] = Field(..., min_length=1)
|
||||||
|
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
hash: str | None = Field(default=None)
|
||||||
|
|
||||||
|
@field_validator("hash", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_hash(cls, v):
|
||||||
if v is None:
|
if v is None:
|
||||||
return None
|
return None
|
||||||
s = str(v).strip()
|
s = str(v).strip().lower()
|
||||||
if not s:
|
if not s:
|
||||||
return None
|
return None
|
||||||
try:
|
if ":" not in s:
|
||||||
uuid.UUID(s)
|
raise ValueError("hash must be 'blake3:<hex>'")
|
||||||
except Exception:
|
algo, digest = s.split(":", 1)
|
||||||
raise ValueError("preview_id must be a UUID")
|
if algo != "blake3":
|
||||||
return s
|
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||||
|
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||||
|
raise ValueError("hash digest must be lowercase hex")
|
||||||
|
return f"{algo}:{digest}"
|
||||||
|
|
||||||
|
@field_validator("tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_tags(cls, v):
|
||||||
|
"""
|
||||||
|
Accepts a list of strings (possibly multiple form fields),
|
||||||
|
where each string can be:
|
||||||
|
- JSON array (e.g., '["models","loras","foo"]')
|
||||||
|
- comma-separated ('models, loras, foo')
|
||||||
|
- single token ('models')
|
||||||
|
Returns a normalized, deduplicated, ordered list.
|
||||||
|
"""
|
||||||
|
items: list[str] = []
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
v = [v]
|
||||||
|
|
||||||
|
if isinstance(v, list):
|
||||||
|
for item in v:
|
||||||
|
if item is None:
|
||||||
|
continue
|
||||||
|
s = str(item).strip()
|
||||||
|
if not s:
|
||||||
|
continue
|
||||||
|
if s.startswith("["):
|
||||||
|
try:
|
||||||
|
arr = json.loads(s)
|
||||||
|
if isinstance(arr, list):
|
||||||
|
items.extend(str(x) for x in arr)
|
||||||
|
continue
|
||||||
|
except Exception:
|
||||||
|
pass # fallback to CSV parse below
|
||||||
|
items.extend([p for p in s.split(",") if p.strip()])
|
||||||
|
else:
|
||||||
|
return []
|
||||||
|
|
||||||
|
# normalize + dedupe
|
||||||
|
norm = []
|
||||||
|
seen = set()
|
||||||
|
for t in items:
|
||||||
|
tnorm = str(t).strip().lower()
|
||||||
|
if tnorm and tnorm not in seen:
|
||||||
|
seen.add(tnorm)
|
||||||
|
norm.append(tnorm)
|
||||||
|
return norm
|
||||||
|
|
||||||
|
@field_validator("user_metadata", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_metadata_json(cls, v):
|
||||||
|
if v is None or isinstance(v, dict):
|
||||||
|
return v or {}
|
||||||
|
if isinstance(v, str):
|
||||||
|
s = v.strip()
|
||||||
|
if not s:
|
||||||
|
return {}
|
||||||
|
try:
|
||||||
|
parsed = json.loads(s)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"user_metadata must be JSON: {e}") from e
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError("user_metadata must be a JSON object")
|
||||||
|
return parsed
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_order(self):
|
||||||
|
if not self.tags:
|
||||||
|
raise ValueError("tags must be provided and non-empty")
|
||||||
|
root = self.tags[0]
|
||||||
|
if root not in {"models", "input", "output"}:
|
||||||
|
raise ValueError("first tag must be one of: models, input, output")
|
||||||
|
if root == "models":
|
||||||
|
if len(self.tags) < 2:
|
||||||
|
raise ValueError("models uploads require a category tag as the second tag")
|
||||||
|
return self
|
||||||
|
|||||||
@ -29,6 +29,21 @@ class AssetsList(BaseModel):
|
|||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class AssetUpdated(BaseModel):
|
||||||
|
id: str
|
||||||
|
name: str
|
||||||
|
asset_hash: str | None = None
|
||||||
|
tags: list[str] = Field(default_factory=list)
|
||||||
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
updated_at: datetime | None = None
|
||||||
|
|
||||||
|
model_config = ConfigDict(from_attributes=True)
|
||||||
|
|
||||||
|
@field_serializer("updated_at")
|
||||||
|
def _ser_updated(self, v: datetime | None, _info):
|
||||||
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
class AssetDetail(BaseModel):
|
class AssetDetail(BaseModel):
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
@ -48,6 +63,10 @@ class AssetDetail(BaseModel):
|
|||||||
return v.isoformat() if v else None
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
class AssetCreated(AssetDetail):
|
||||||
|
created_new: bool
|
||||||
|
|
||||||
|
|
||||||
class TagUsage(BaseModel):
|
class TagUsage(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
@ -58,3 +77,17 @@ class TagsList(BaseModel):
|
|||||||
tags: list[TagUsage] = Field(default_factory=list)
|
tags: list[TagUsage] = Field(default_factory=list)
|
||||||
total: int
|
total: int
|
||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
|
class TagsAdd(BaseModel):
|
||||||
|
model_config = ConfigDict(str_strip_whitespace=True)
|
||||||
|
added: list[str] = Field(default_factory=list)
|
||||||
|
already_present: list[str] = Field(default_factory=list)
|
||||||
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRemove(BaseModel):
|
||||||
|
model_config = ConfigDict(str_strip_whitespace=True)
|
||||||
|
removed: list[str] = Field(default_factory=list)
|
||||||
|
not_present: list[str] = Field(default_factory=list)
|
||||||
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|||||||
@ -1,9 +1,17 @@
|
|||||||
|
import os
|
||||||
|
import logging
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from collections import defaultdict
|
from collections import defaultdict
|
||||||
from sqlalchemy import select, exists, func
|
from datetime import datetime
|
||||||
|
from typing import Iterable, Any
|
||||||
|
from sqlalchemy import select, delete, exists, func
|
||||||
|
from sqlalchemy.dialects import sqlite
|
||||||
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session, contains_eager, noload
|
from sqlalchemy.orm import Session, contains_eager, noload
|
||||||
from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag
|
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||||
from app.assets.helpers import escape_like_prefix, normalize_tags
|
from app.assets.helpers import (
|
||||||
|
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||||
|
)
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
|
|
||||||
@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
|||||||
return AssetInfo.owner_id.in_(["", owner_id])
|
return AssetInfo.owner_id.in_(["", owner_id])
|
||||||
|
|
||||||
|
|
||||||
|
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||||
|
"""
|
||||||
|
Return the best on-disk path among cache states:
|
||||||
|
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||||
|
2) Otherwise, pick the first path that exists.
|
||||||
|
3) Otherwise return empty string.
|
||||||
|
"""
|
||||||
|
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||||
|
if not alive:
|
||||||
|
return ""
|
||||||
|
for s in alive:
|
||||||
|
if not getattr(s, "needs_verify", False):
|
||||||
|
return s.file_path
|
||||||
|
return alive[0].file_path
|
||||||
|
|
||||||
|
|
||||||
def apply_tag_filters(
|
def apply_tag_filters(
|
||||||
stmt: sa.sql.Select,
|
stmt: sa.sql.Select,
|
||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
@ -42,6 +66,7 @@ def apply_tag_filters(
|
|||||||
)
|
)
|
||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def apply_metadata_filter(
|
def apply_metadata_filter(
|
||||||
stmt: sa.sql.Select,
|
stmt: sa.sql.Select,
|
||||||
metadata_filter: dict | None = None,
|
metadata_filter: dict | None = None,
|
||||||
@ -94,7 +119,11 @@ def apply_metadata_filter(
|
|||||||
return stmt
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
def asset_exists_by_hash(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
) -> bool:
|
||||||
"""
|
"""
|
||||||
Check if an asset with a given hash exists in database.
|
Check if an asset with a given hash exists in database.
|
||||||
"""
|
"""
|
||||||
@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool:
|
|||||||
).first()
|
).first()
|
||||||
return row is not None
|
return row is not None
|
||||||
|
|
||||||
def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None:
|
|
||||||
|
def asset_info_exists_for_asset_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_id: str,
|
||||||
|
) -> bool:
|
||||||
|
q = (
|
||||||
|
select(sa.literal(True))
|
||||||
|
.select_from(AssetInfo)
|
||||||
|
.where(AssetInfo.asset_id == asset_id)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
return (session.execute(q)).first() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_by_hash(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
) -> Asset | None:
|
||||||
|
return (
|
||||||
|
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_info_by_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
) -> AssetInfo | None:
|
||||||
return session.get(AssetInfo, asset_info_id)
|
return session.get(AssetInfo, asset_info_id)
|
||||||
|
|
||||||
|
|
||||||
def list_asset_infos_page(
|
def list_asset_infos_page(
|
||||||
session: Session,
|
session: Session,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
@ -171,12 +230,14 @@ def list_asset_infos_page(
|
|||||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||||
|
.order_by(AssetInfoTag.added_at)
|
||||||
)
|
)
|
||||||
for aid, tag_name in rows.all():
|
for aid, tag_name in rows.all():
|
||||||
tag_map[aid].append(tag_name)
|
tag_map[aid].append(tag_name)
|
||||||
|
|
||||||
return infos, tag_map, total
|
return infos, tag_map, total
|
||||||
|
|
||||||
|
|
||||||
def fetch_asset_info_asset_and_tags(
|
def fetch_asset_info_asset_and_tags(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_info_id: str,
|
asset_info_id: str,
|
||||||
@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags(
|
|||||||
tags.append(tag_name)
|
tags.append(tag_name)
|
||||||
return first_info, first_asset, tags
|
return first_info, first_asset, tags
|
||||||
|
|
||||||
|
|
||||||
|
def fetch_asset_info_and_asset(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> tuple[AssetInfo, Asset] | None:
|
||||||
|
stmt = (
|
||||||
|
select(AssetInfo, Asset)
|
||||||
|
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||||
|
.where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
)
|
||||||
|
row = session.execute(stmt)
|
||||||
|
pair = row.first()
|
||||||
|
if not pair:
|
||||||
|
return None
|
||||||
|
return pair[0], pair[1]
|
||||||
|
|
||||||
|
def list_cache_states_by_asset_id(
|
||||||
|
session: Session, *, asset_id: str
|
||||||
|
) -> Sequence[AssetCacheState]:
|
||||||
|
return (
|
||||||
|
session.execute(
|
||||||
|
select(AssetCacheState)
|
||||||
|
.where(AssetCacheState.asset_id == asset_id)
|
||||||
|
.order_by(AssetCacheState.id.asc())
|
||||||
|
)
|
||||||
|
).scalars().all()
|
||||||
|
|
||||||
|
|
||||||
|
def touch_asset_info_by_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
ts: datetime | None = None,
|
||||||
|
only_if_newer: bool = True,
|
||||||
|
) -> None:
|
||||||
|
ts = ts or utcnow()
|
||||||
|
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||||
|
if only_if_newer:
|
||||||
|
stmt = stmt.where(
|
||||||
|
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||||
|
)
|
||||||
|
session.execute(stmt.values(last_access_time=ts))
|
||||||
|
|
||||||
|
|
||||||
|
def create_asset_info_for_existing_asset(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
name: str,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
tags: Sequence[str] | None = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> AssetInfo:
|
||||||
|
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||||
|
now = utcnow()
|
||||||
|
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||||
|
if not asset:
|
||||||
|
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||||
|
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=None,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
try:
|
||||||
|
with session.begin_nested():
|
||||||
|
session.add(info)
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
existing = (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.options(noload(AssetInfo.tags))
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == asset.id,
|
||||||
|
AssetInfo.name == name,
|
||||||
|
AssetInfo.owner_id == owner_id,
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalars().first()
|
||||||
|
if not existing:
|
||||||
|
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||||
|
return existing
|
||||||
|
|
||||||
|
# metadata["filename"] hack
|
||||||
|
new_meta = dict(user_metadata or {})
|
||||||
|
computed_filename = None
|
||||||
|
try:
|
||||||
|
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||||
|
if p:
|
||||||
|
computed_filename = compute_relative_filename(p)
|
||||||
|
except Exception:
|
||||||
|
computed_filename = None
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
if new_meta:
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=info.id,
|
||||||
|
user_metadata=new_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=info.id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
)
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def set_asset_info_tags(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
) -> dict:
|
||||||
|
desired = normalize_tags(tags)
|
||||||
|
|
||||||
|
current = set(
|
||||||
|
tag_name for (tag_name,) in (
|
||||||
|
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
|
||||||
|
to_add = [t for t in desired if t not in current]
|
||||||
|
to_remove = [t for t in current if t not in desired]
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
ensure_tags_exist(session, to_add, tag_type="user")
|
||||||
|
session.add_all([
|
||||||
|
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||||
|
for t in to_add
|
||||||
|
])
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||||
|
|
||||||
|
|
||||||
|
def replace_asset_info_metadata_projection(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
info.user_metadata = user_metadata or {}
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
if not user_metadata:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows: list[AssetInfoMeta] = []
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
for r in project_kv(k, v):
|
||||||
|
rows.append(
|
||||||
|
AssetInfoMeta(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
key=r["key"],
|
||||||
|
ordinal=int(r["ordinal"]),
|
||||||
|
val_str=r.get("val_str"),
|
||||||
|
val_num=r.get("val_num"),
|
||||||
|
val_bool=r.get("val_bool"),
|
||||||
|
val_json=r.get("val_json"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rows:
|
||||||
|
session.add_all(rows)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
|
||||||
|
def ingest_fs_asset(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_hash: str,
|
||||||
|
abs_path: str,
|
||||||
|
size_bytes: int,
|
||||||
|
mtime_ns: int,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
info_name: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
preview_id: str | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
tags: Sequence[str] = (),
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
require_existing_tags: bool = False,
|
||||||
|
) -> dict:
|
||||||
|
"""
|
||||||
|
Idempotently upsert:
|
||||||
|
- Asset by content hash (create if missing)
|
||||||
|
- AssetCacheState(file_path) pointing to asset_id
|
||||||
|
- Optionally AssetInfo + tag links and metadata projection
|
||||||
|
Returns flags and ids.
|
||||||
|
"""
|
||||||
|
locator = os.path.abspath(abs_path)
|
||||||
|
now = utcnow()
|
||||||
|
|
||||||
|
if preview_id:
|
||||||
|
if not session.get(Asset, preview_id):
|
||||||
|
preview_id = None
|
||||||
|
|
||||||
|
out: dict[str, Any] = {
|
||||||
|
"asset_created": False,
|
||||||
|
"asset_updated": False,
|
||||||
|
"state_created": False,
|
||||||
|
"state_updated": False,
|
||||||
|
"asset_info_id": None,
|
||||||
|
}
|
||||||
|
|
||||||
|
# 1) Asset by hash
|
||||||
|
asset = (
|
||||||
|
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||||
|
).scalars().first()
|
||||||
|
if not asset:
|
||||||
|
vals = {
|
||||||
|
"hash": asset_hash,
|
||||||
|
"size_bytes": int(size_bytes),
|
||||||
|
"mime_type": mime_type,
|
||||||
|
"created_at": now,
|
||||||
|
}
|
||||||
|
res = session.execute(
|
||||||
|
sqlite.insert(Asset)
|
||||||
|
.values(**vals)
|
||||||
|
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||||
|
)
|
||||||
|
if int(res.rowcount or 0) > 0:
|
||||||
|
out["asset_created"] = True
|
||||||
|
asset = (
|
||||||
|
session.execute(
|
||||||
|
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||||
|
)
|
||||||
|
).scalars().first()
|
||||||
|
if not asset:
|
||||||
|
raise RuntimeError("Asset row not found after upsert.")
|
||||||
|
else:
|
||||||
|
changed = False
|
||||||
|
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||||
|
asset.size_bytes = int(size_bytes)
|
||||||
|
changed = True
|
||||||
|
if mime_type and asset.mime_type != mime_type:
|
||||||
|
asset.mime_type = mime_type
|
||||||
|
changed = True
|
||||||
|
if changed:
|
||||||
|
out["asset_updated"] = True
|
||||||
|
|
||||||
|
# 2) AssetCacheState upsert by file_path (unique)
|
||||||
|
vals = {
|
||||||
|
"asset_id": asset.id,
|
||||||
|
"file_path": locator,
|
||||||
|
"mtime_ns": int(mtime_ns),
|
||||||
|
}
|
||||||
|
ins = (
|
||||||
|
sqlite.insert(AssetCacheState)
|
||||||
|
.values(**vals)
|
||||||
|
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||||
|
)
|
||||||
|
|
||||||
|
res = session.execute(ins)
|
||||||
|
if int(res.rowcount or 0) > 0:
|
||||||
|
out["state_created"] = True
|
||||||
|
else:
|
||||||
|
upd = (
|
||||||
|
sa.update(AssetCacheState)
|
||||||
|
.where(AssetCacheState.file_path == locator)
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetCacheState.asset_id != asset.id,
|
||||||
|
AssetCacheState.mtime_ns.is_(None),
|
||||||
|
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||||
|
)
|
||||||
|
res2 = session.execute(upd)
|
||||||
|
if int(res2.rowcount or 0) > 0:
|
||||||
|
out["state_updated"] = True
|
||||||
|
|
||||||
|
# 3) Optional AssetInfo + tags + metadata
|
||||||
|
if info_name:
|
||||||
|
try:
|
||||||
|
with session.begin_nested():
|
||||||
|
info = AssetInfo(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=info_name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
preview_id=preview_id,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(info)
|
||||||
|
session.flush()
|
||||||
|
out["asset_info_id"] = info.id
|
||||||
|
except IntegrityError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
existing_info = (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfo)
|
||||||
|
.where(
|
||||||
|
AssetInfo.asset_id == asset.id,
|
||||||
|
AssetInfo.name == info_name,
|
||||||
|
(AssetInfo.owner_id == owner_id),
|
||||||
|
)
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
).unique().scalar_one_or_none()
|
||||||
|
if not existing_info:
|
||||||
|
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||||
|
|
||||||
|
if preview_id and existing_info.preview_id != preview_id:
|
||||||
|
existing_info.preview_id = preview_id
|
||||||
|
|
||||||
|
existing_info.updated_at = now
|
||||||
|
if existing_info.last_access_time < now:
|
||||||
|
existing_info.last_access_time = now
|
||||||
|
session.flush()
|
||||||
|
out["asset_info_id"] = existing_info.id
|
||||||
|
|
||||||
|
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||||
|
if norm and out["asset_info_id"] is not None:
|
||||||
|
if not require_existing_tags:
|
||||||
|
ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
existing_tag_names = set(
|
||||||
|
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||||
|
)
|
||||||
|
missing = [t for t in norm if t not in existing_tag_names]
|
||||||
|
if missing and require_existing_tags:
|
||||||
|
raise ValueError(f"Unknown tags: {missing}")
|
||||||
|
|
||||||
|
existing_links = set(
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
)
|
||||||
|
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||||
|
if to_add:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
tag_name=t,
|
||||||
|
origin=tag_origin,
|
||||||
|
added_at=now,
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
# metadata["filename"] hack
|
||||||
|
if out["asset_info_id"] is not None:
|
||||||
|
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||||
|
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||||
|
|
||||||
|
current_meta = existing_info.user_metadata or {}
|
||||||
|
new_meta = dict(current_meta)
|
||||||
|
if user_metadata is not None:
|
||||||
|
for k, v in user_metadata.items():
|
||||||
|
new_meta[k] = v
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
|
||||||
|
if new_meta != current_meta:
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session,
|
||||||
|
asset_info_id=out["asset_info_id"],
|
||||||
|
user_metadata=new_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||||
|
except Exception:
|
||||||
|
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||||
|
return out
|
||||||
|
|
||||||
|
|
||||||
|
def update_asset_info_full(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
tags: Sequence[str] | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
tag_origin: str = "manual",
|
||||||
|
asset_info_row: Any = None,
|
||||||
|
) -> AssetInfo:
|
||||||
|
if not asset_info_row:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
else:
|
||||||
|
info = asset_info_row
|
||||||
|
|
||||||
|
touched = False
|
||||||
|
if name is not None and name != info.name:
|
||||||
|
info.name = name
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
computed_filename = None
|
||||||
|
try:
|
||||||
|
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||||
|
if p:
|
||||||
|
computed_filename = compute_relative_filename(p)
|
||||||
|
except Exception:
|
||||||
|
computed_filename = None
|
||||||
|
|
||||||
|
if user_metadata is not None:
|
||||||
|
new_meta = dict(user_metadata)
|
||||||
|
if computed_filename:
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
else:
|
||||||
|
if computed_filename:
|
||||||
|
current_meta = info.user_metadata or {}
|
||||||
|
if current_meta.get("filename") != computed_filename:
|
||||||
|
new_meta = dict(current_meta)
|
||||||
|
new_meta["filename"] = computed_filename
|
||||||
|
replace_asset_info_metadata_projection(
|
||||||
|
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if tags is not None:
|
||||||
|
set_asset_info_tags(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=tag_origin,
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if touched and user_metadata is None:
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
return info
|
||||||
|
|
||||||
|
|
||||||
|
def delete_asset_info_by_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str,
|
||||||
|
) -> bool:
|
||||||
|
stmt = sa.delete(AssetInfo).where(
|
||||||
|
AssetInfo.id == asset_info_id,
|
||||||
|
visible_owner_clause(owner_id),
|
||||||
|
)
|
||||||
|
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||||
|
|
||||||
|
|
||||||
def list_tags_with_usage(
|
def list_tags_with_usage(
|
||||||
session: Session,
|
session: Session,
|
||||||
prefix: str | None = None,
|
prefix: str | None = None,
|
||||||
@ -265,3 +814,163 @@ def list_tags_with_usage(
|
|||||||
|
|
||||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||||
return rows_norm, int(total or 0)
|
return rows_norm, int(total or 0)
|
||||||
|
|
||||||
|
|
||||||
|
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||||
|
wanted = normalize_tags(list(names))
|
||||||
|
if not wanted:
|
||||||
|
return
|
||||||
|
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||||
|
ins = (
|
||||||
|
sqlite.insert(Tag)
|
||||||
|
.values(rows)
|
||||||
|
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||||
|
)
|
||||||
|
session.execute(ins)
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||||
|
return [
|
||||||
|
tag_name for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
def add_tags_to_asset_info(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
create_if_missing: bool = True,
|
||||||
|
asset_info_row: Any = None,
|
||||||
|
) -> dict:
|
||||||
|
if not asset_info_row:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"added": [], "already_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
if create_if_missing:
|
||||||
|
ensure_tags_exist(session, norm, tag_type="user")
|
||||||
|
|
||||||
|
current = {
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
want = set(norm)
|
||||||
|
to_add = sorted(want - current)
|
||||||
|
|
||||||
|
if to_add:
|
||||||
|
with session.begin_nested() as nested:
|
||||||
|
try:
|
||||||
|
session.add_all(
|
||||||
|
[
|
||||||
|
AssetInfoTag(
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tag_name=t,
|
||||||
|
origin=origin,
|
||||||
|
added_at=utcnow(),
|
||||||
|
)
|
||||||
|
for t in to_add
|
||||||
|
]
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
except IntegrityError:
|
||||||
|
nested.rollback()
|
||||||
|
|
||||||
|
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||||
|
return {
|
||||||
|
"added": sorted(((after - current) & want)),
|
||||||
|
"already_present": sorted(want & current),
|
||||||
|
"total_tags": sorted(after),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tags_from_asset_info(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: Sequence[str],
|
||||||
|
) -> dict:
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
norm = normalize_tags(tags)
|
||||||
|
if not norm:
|
||||||
|
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": [], "not_present": [], "total_tags": total}
|
||||||
|
|
||||||
|
existing = {
|
||||||
|
tag_name
|
||||||
|
for (tag_name,) in (
|
||||||
|
session.execute(
|
||||||
|
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
}
|
||||||
|
|
||||||
|
to_remove = sorted(set(t for t in norm if t in existing))
|
||||||
|
not_present = sorted(set(t for t in norm if t not in existing))
|
||||||
|
|
||||||
|
if to_remove:
|
||||||
|
session.execute(
|
||||||
|
delete(AssetInfoTag)
|
||||||
|
.where(
|
||||||
|
AssetInfoTag.asset_info_id == asset_info_id,
|
||||||
|
AssetInfoTag.tag_name.in_(to_remove),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||||
|
|
||||||
|
|
||||||
|
def remove_missing_tag_for_asset_id(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_id: str,
|
||||||
|
) -> None:
|
||||||
|
session.execute(
|
||||||
|
sa.delete(AssetInfoTag).where(
|
||||||
|
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||||
|
AssetInfoTag.tag_name == "missing",
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_asset_info_preview(
|
||||||
|
session: Session,
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
preview_asset_id: str | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||||
|
info = session.get(AssetInfo, asset_info_id)
|
||||||
|
if not info:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
if preview_asset_id is None:
|
||||||
|
info.preview_id = None
|
||||||
|
else:
|
||||||
|
# validate preview asset exists
|
||||||
|
if not session.get(Asset, preview_asset_id):
|
||||||
|
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||||
|
info.preview_id = preview_asset_id
|
||||||
|
|
||||||
|
info.updated_at = utcnow()
|
||||||
|
session.flush()
|
||||||
|
|||||||
@ -1,5 +1,6 @@
|
|||||||
import contextlib
|
import contextlib
|
||||||
import os
|
import os
|
||||||
|
from decimal import Decimal
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from datetime import datetime, timezone
|
from datetime import datetime, timezone
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
|||||||
targets.append((name, paths))
|
targets.append((name, paths))
|
||||||
return targets
|
return targets
|
||||||
|
|
||||||
|
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||||
|
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||||
|
root = tags[0]
|
||||||
|
if root == "models":
|
||||||
|
if len(tags) < 2:
|
||||||
|
raise ValueError("at least two tags required for model asset")
|
||||||
|
try:
|
||||||
|
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||||
|
if not bases:
|
||||||
|
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||||
|
base_dir = os.path.abspath(bases[0])
|
||||||
|
raw_subdirs = tags[2:]
|
||||||
|
else:
|
||||||
|
base_dir = os.path.abspath(
|
||||||
|
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||||
|
)
|
||||||
|
raw_subdirs = tags[1:]
|
||||||
|
for i in raw_subdirs:
|
||||||
|
if i in (".", ".."):
|
||||||
|
raise ValueError("invalid path component in tags")
|
||||||
|
|
||||||
|
return base_dir, raw_subdirs if raw_subdirs else []
|
||||||
|
|
||||||
|
def ensure_within_base(candidate: str, base: str) -> None:
|
||||||
|
cand_abs = os.path.abspath(candidate)
|
||||||
|
base_abs = os.path.abspath(base)
|
||||||
|
try:
|
||||||
|
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||||
|
raise ValueError("destination escapes base directory")
|
||||||
|
except Exception:
|
||||||
|
raise ValueError("invalid destination path")
|
||||||
|
|
||||||
def compute_relative_filename(file_path: str) -> str | None:
|
def compute_relative_filename(file_path: str) -> str | None:
|
||||||
"""
|
"""
|
||||||
Return the model's path relative to the last well-known folder (the model category),
|
Return the model's path relative to the last well-known folder (the model category),
|
||||||
@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None:
|
|||||||
return "/".join(inside)
|
return "/".join(inside)
|
||||||
return "/".join(parts) # input/output: keep all parts
|
return "/".join(parts) # input/output: keep all parts
|
||||||
|
|
||||||
|
|
||||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||||
@ -215,3 +249,64 @@ def collect_models_files() -> list[str]:
|
|||||||
if allowed:
|
if allowed:
|
||||||
out.append(abs_path)
|
out.append(abs_path)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def is_scalar(v):
|
||||||
|
if v is None:
|
||||||
|
return True
|
||||||
|
if isinstance(v, bool):
|
||||||
|
return True
|
||||||
|
if isinstance(v, (int, float, Decimal, str)):
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
def project_kv(key: str, value):
|
||||||
|
"""
|
||||||
|
Turn a metadata key/value into typed projection rows.
|
||||||
|
Returns list[dict] with keys:
|
||||||
|
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||||
|
"""
|
||||||
|
rows: list[dict] = []
|
||||||
|
|
||||||
|
def _null_row(ordinal: int) -> dict:
|
||||||
|
return {
|
||||||
|
"key": key, "ordinal": ordinal,
|
||||||
|
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||||
|
}
|
||||||
|
|
||||||
|
if value is None:
|
||||||
|
rows.append(_null_row(0))
|
||||||
|
return rows
|
||||||
|
|
||||||
|
if is_scalar(value):
|
||||||
|
if isinstance(value, bool):
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||||
|
elif isinstance(value, (int, float, Decimal)):
|
||||||
|
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||||
|
elif isinstance(value, str):
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||||
|
else:
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
if isinstance(value, list):
|
||||||
|
if all(is_scalar(x) for x in value):
|
||||||
|
for i, x in enumerate(value):
|
||||||
|
if x is None:
|
||||||
|
rows.append(_null_row(i))
|
||||||
|
elif isinstance(x, bool):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||||
|
elif isinstance(x, (int, float, Decimal)):
|
||||||
|
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||||
|
elif isinstance(x, str):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||||
|
else:
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||||
|
return rows
|
||||||
|
for i, x in enumerate(value):
|
||||||
|
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||||
|
return rows
|
||||||
|
|
||||||
|
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||||
|
return rows
|
||||||
|
|||||||
@ -1,13 +1,33 @@
|
|||||||
|
import os
|
||||||
|
import mimetypes
|
||||||
|
import contextlib
|
||||||
from typing import Sequence
|
from typing import Sequence
|
||||||
|
|
||||||
from app.database.db import create_session
|
from app.database.db import create_session
|
||||||
from app.assets.api import schemas_out
|
from app.assets.api import schemas_out, schemas_in
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
asset_exists_by_hash,
|
asset_exists_by_hash,
|
||||||
|
asset_info_exists_for_asset_id,
|
||||||
|
get_asset_by_hash,
|
||||||
|
get_asset_info_by_id,
|
||||||
fetch_asset_info_asset_and_tags,
|
fetch_asset_info_asset_and_tags,
|
||||||
|
fetch_asset_info_and_asset,
|
||||||
|
create_asset_info_for_existing_asset,
|
||||||
|
touch_asset_info_by_id,
|
||||||
|
update_asset_info_full,
|
||||||
|
delete_asset_info_by_id,
|
||||||
|
list_cache_states_by_asset_id,
|
||||||
list_asset_infos_page,
|
list_asset_infos_page,
|
||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
|
get_asset_tags,
|
||||||
|
add_tags_to_asset_info,
|
||||||
|
remove_tags_from_asset_info,
|
||||||
|
pick_best_live_path,
|
||||||
|
ingest_fs_asset,
|
||||||
|
set_asset_info_preview,
|
||||||
)
|
)
|
||||||
|
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||||
|
from app.assets.database.models import Asset
|
||||||
|
|
||||||
|
|
||||||
def _safe_sort_field(requested: str | None) -> str:
|
def _safe_sort_field(requested: str | None) -> str:
|
||||||
@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str:
|
|||||||
return "created_at"
|
return "created_at"
|
||||||
|
|
||||||
|
|
||||||
def asset_exists(asset_hash: str) -> bool:
|
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||||
|
st = os.stat(path, follow_symlinks=True)
|
||||||
|
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||||
|
|
||||||
|
|
||||||
|
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||||
|
n = os.path.basename((name or "").strip() or fallback)
|
||||||
|
if n:
|
||||||
|
return n
|
||||||
|
return fallback
|
||||||
|
|
||||||
|
|
||||||
|
def asset_exists(*, asset_hash: str) -> bool:
|
||||||
|
"""
|
||||||
|
Check if an asset with a given hash exists in database.
|
||||||
|
"""
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||||
|
|
||||||
|
|
||||||
def list_assets(
|
def list_assets(
|
||||||
|
*,
|
||||||
include_tags: Sequence[str] | None = None,
|
include_tags: Sequence[str] | None = None,
|
||||||
exclude_tags: Sequence[str] | None = None,
|
exclude_tags: Sequence[str] | None = None,
|
||||||
name_contains: str | None = None,
|
name_contains: str | None = None,
|
||||||
@ -63,7 +100,6 @@ def list_assets(
|
|||||||
size=int(asset.size_bytes) if asset else None,
|
size=int(asset.size_bytes) if asset else None,
|
||||||
mime_type=asset.mime_type if asset else None,
|
mime_type=asset.mime_type if asset else None,
|
||||||
tags=tags,
|
tags=tags,
|
||||||
preview_url=f"/api/assets/{info.id}/content",
|
|
||||||
created_at=info.created_at,
|
created_at=info.created_at,
|
||||||
updated_at=info.updated_at,
|
updated_at=info.updated_at,
|
||||||
last_access_time=info.last_access_time,
|
last_access_time=info.last_access_time,
|
||||||
@ -76,7 +112,12 @@ def list_assets(
|
|||||||
has_more=(offset + len(summaries)) < total,
|
has_more=(offset + len(summaries)) < total,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail:
|
|
||||||
|
def get_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetDetail:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
if not res:
|
if not res:
|
||||||
@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail
|
|||||||
last_access_time=info.last_access_time,
|
last_access_time=info.last_access_time,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_asset_content_for_download(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> tuple[str, str, str]:
|
||||||
|
with create_session() as session:
|
||||||
|
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
|
if not pair:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
|
||||||
|
info, asset = pair
|
||||||
|
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||||
|
abs_path = pick_best_live_path(states)
|
||||||
|
if not abs_path:
|
||||||
|
raise FileNotFoundError
|
||||||
|
|
||||||
|
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||||
|
download_name = info.name or os.path.basename(abs_path)
|
||||||
|
return abs_path, ctype, download_name
|
||||||
|
|
||||||
|
|
||||||
|
def upload_asset_from_temp_path(
|
||||||
|
spec: schemas_in.UploadAssetSpec,
|
||||||
|
*,
|
||||||
|
temp_path: str,
|
||||||
|
client_filename: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
expected_asset_hash: str | None = None,
|
||||||
|
) -> schemas_out.AssetCreated:
|
||||||
|
"""
|
||||||
|
Create new asset or update existing asset from a temporary file path.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||||
|
import app.assets.hashing as hashing
|
||||||
|
digest = hashing.blake3_hash(temp_path)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||||
|
asset_hash = "blake3:" + digest
|
||||||
|
|
||||||
|
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||||
|
raise ValueError("HASH_MISMATCH")
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||||
|
if existing is not None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
if temp_path and os.path.exists(temp_path):
|
||||||
|
os.remove(temp_path)
|
||||||
|
|
||||||
|
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||||
|
info = create_asset_info_for_existing_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
name=display_name,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
tags=spec.tags or [],
|
||||||
|
tag_origin="manual",
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=existing.hash,
|
||||||
|
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||||
|
mime_type=existing.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||||
|
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||||
|
os.makedirs(dest_dir, exist_ok=True)
|
||||||
|
|
||||||
|
src_for_ext = (client_filename or spec.name or "").strip()
|
||||||
|
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||||
|
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||||
|
hashed_basename = f"{digest}{ext}"
|
||||||
|
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||||
|
ensure_within_base(dest_abs, base_dir)
|
||||||
|
|
||||||
|
content_type = (
|
||||||
|
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||||
|
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
try:
|
||||||
|
os.replace(temp_path, dest_abs)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||||
|
|
||||||
|
try:
|
||||||
|
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||||
|
except OSError as e:
|
||||||
|
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
result = ingest_fs_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
abs_path=dest_abs,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mtime_ns=mtime_ns,
|
||||||
|
mime_type=content_type,
|
||||||
|
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||||
|
owner_id=owner_id,
|
||||||
|
preview_id=None,
|
||||||
|
user_metadata=spec.user_metadata or {},
|
||||||
|
tags=spec.tags,
|
||||||
|
tag_origin="manual",
|
||||||
|
require_existing_tags=False,
|
||||||
|
)
|
||||||
|
info_id = result["asset_info_id"]
|
||||||
|
if not info_id:
|
||||||
|
raise RuntimeError("failed to create asset metadata")
|
||||||
|
|
||||||
|
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||||
|
if not pair:
|
||||||
|
raise RuntimeError("inconsistent DB state after ingest")
|
||||||
|
info, asset = pair
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
created_result = schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=asset.hash,
|
||||||
|
size=int(asset.size_bytes),
|
||||||
|
mime_type=asset.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=result["asset_created"],
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return created_result
|
||||||
|
|
||||||
|
|
||||||
|
def update_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
name: str | None = None,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetUpdated:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
|
info = update_asset_info_full(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
name=name,
|
||||||
|
tags=tags,
|
||||||
|
user_metadata=user_metadata,
|
||||||
|
tag_origin="manual",
|
||||||
|
asset_info_row=info_row,
|
||||||
|
)
|
||||||
|
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||||
|
result = schemas_out.AssetUpdated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=info.asset.hash if info.asset else None,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
updated_at=info.updated_at,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def set_asset_preview(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
preview_asset_id: str | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetDetail:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
|
set_asset_info_preview(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
preview_asset_id=preview_asset_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
|
if not res:
|
||||||
|
raise RuntimeError("State changed during preview update")
|
||||||
|
info, asset, tags = res
|
||||||
|
result = schemas_out.AssetDetail(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=asset.hash if asset else None,
|
||||||
|
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||||
|
mime_type=asset.mime_type if asset else None,
|
||||||
|
tags=tags,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
asset_id = info_row.asset_id if info_row else None
|
||||||
|
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||||
|
if not deleted:
|
||||||
|
session.commit()
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not delete_content_if_orphan or not asset_id:
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||||
|
if still_exists:
|
||||||
|
session.commit()
|
||||||
|
return True
|
||||||
|
|
||||||
|
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||||
|
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||||
|
|
||||||
|
asset_row = session.get(Asset, asset_id)
|
||||||
|
if asset_row is not None:
|
||||||
|
session.delete(asset_row)
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
for p in file_paths:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
if p and os.path.isfile(p):
|
||||||
|
os.remove(p)
|
||||||
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
def create_asset_from_hash(
|
||||||
|
*,
|
||||||
|
hash_str: str,
|
||||||
|
name: str,
|
||||||
|
tags: list[str] | None = None,
|
||||||
|
user_metadata: dict | None = None,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.AssetCreated | None:
|
||||||
|
canonical = hash_str.strip().lower()
|
||||||
|
with create_session() as session:
|
||||||
|
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||||
|
if not asset:
|
||||||
|
return None
|
||||||
|
|
||||||
|
info = create_asset_info_for_existing_asset(
|
||||||
|
session,
|
||||||
|
asset_hash=canonical,
|
||||||
|
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||||
|
user_metadata=user_metadata or {},
|
||||||
|
tags=tags or [],
|
||||||
|
tag_origin="manual",
|
||||||
|
owner_id=owner_id,
|
||||||
|
)
|
||||||
|
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||||
|
result = schemas_out.AssetCreated(
|
||||||
|
id=info.id,
|
||||||
|
name=info.name,
|
||||||
|
asset_hash=asset.hash,
|
||||||
|
size=int(asset.size_bytes),
|
||||||
|
mime_type=asset.mime_type,
|
||||||
|
tags=tag_names,
|
||||||
|
user_metadata=info.user_metadata or {},
|
||||||
|
preview_id=info.preview_id,
|
||||||
|
created_at=info.created_at,
|
||||||
|
last_access_time=info.last_access_time,
|
||||||
|
created_new=False,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
|
def add_tags_to_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: list[str],
|
||||||
|
origin: str = "manual",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.TagsAdd:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
data = add_tags_to_asset_info(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
origin=origin,
|
||||||
|
create_if_missing=True,
|
||||||
|
asset_info_row=info_row,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return schemas_out.TagsAdd(**data)
|
||||||
|
|
||||||
|
|
||||||
|
def remove_tags_from_asset(
|
||||||
|
*,
|
||||||
|
asset_info_id: str,
|
||||||
|
tags: list[str],
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> schemas_out.TagsRemove:
|
||||||
|
with create_session() as session:
|
||||||
|
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||||
|
if not info_row:
|
||||||
|
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||||
|
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||||
|
raise PermissionError("not owner")
|
||||||
|
|
||||||
|
data = remove_tags_from_asset_info(
|
||||||
|
session,
|
||||||
|
asset_info_id=asset_info_id,
|
||||||
|
tags=tags,
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
return schemas_out.TagsRemove(**data)
|
||||||
|
|
||||||
|
|
||||||
def list_tags(
|
def list_tags(
|
||||||
prefix: str | None = None,
|
prefix: str | None = None,
|
||||||
limit: int = 100,
|
limit: int = 100,
|
||||||
|
|||||||
@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
|||||||
t_start = time.perf_counter()
|
t_start = time.perf_counter()
|
||||||
created = 0
|
created = 0
|
||||||
skipped_existing = 0
|
skipped_existing = 0
|
||||||
|
orphans_pruned = 0
|
||||||
paths: list[str] = []
|
paths: list[str] = []
|
||||||
try:
|
try:
|
||||||
existing_paths: set[str] = set()
|
existing_paths: set[str] = set()
|
||||||
@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
|||||||
except Exception as e:
|
except Exception as e:
|
||||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||||
|
|
||||||
|
try:
|
||||||
|
orphans_pruned = _prune_orphaned_assets(roots)
|
||||||
|
except Exception as e:
|
||||||
|
logging.exception("orphan pruning failed: %s", e)
|
||||||
|
|
||||||
if "models" in roots:
|
if "models" in roots:
|
||||||
paths.extend(collect_models_files())
|
paths.extend(collect_models_files())
|
||||||
if "input" in roots:
|
if "input" in roots:
|
||||||
@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No
|
|||||||
finally:
|
finally:
|
||||||
if enable_logging:
|
if enable_logging:
|
||||||
logging.info(
|
logging.info(
|
||||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)",
|
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||||
roots,
|
roots,
|
||||||
time.perf_counter() - t_start,
|
time.perf_counter() - t_start,
|
||||||
created,
|
created,
|
||||||
skipped_existing,
|
skipped_existing,
|
||||||
|
orphans_pruned,
|
||||||
len(paths),
|
len(paths),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||||
|
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||||
|
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||||
|
if not all_prefixes:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
def make_prefix_condition(prefix: str):
|
||||||
|
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||||
|
escaped, esc = escape_like_prefix(base)
|
||||||
|
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||||
|
|
||||||
|
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||||
|
|
||||||
|
orphan_subq = (
|
||||||
|
sqlalchemy.select(Asset.id)
|
||||||
|
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||||
|
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||||
|
).scalar_subquery()
|
||||||
|
|
||||||
|
with create_session() as sess:
|
||||||
|
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||||
|
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||||
|
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||||
|
sess.commit()
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
|
||||||
def _fast_db_consistency_pass(
|
def _fast_db_consistency_pass(
|
||||||
root: RootType,
|
root: RootType,
|
||||||
*,
|
*,
|
||||||
|
|||||||
271
tests-unit/assets_test/conftest.py
Normal file
271
tests-unit/assets_test/conftest.py
Normal file
@ -0,0 +1,271 @@
|
|||||||
|
import contextlib
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import socket
|
||||||
|
import subprocess
|
||||||
|
import sys
|
||||||
|
import tempfile
|
||||||
|
import time
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, Iterator, Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def pytest_addoption(parser: pytest.Parser) -> None:
|
||||||
|
"""
|
||||||
|
Allow overriding the database URL used by the spawned ComfyUI process.
|
||||||
|
Priority:
|
||||||
|
1) --db-url command line option
|
||||||
|
2) ASSETS_TEST_DB_URL environment variable (used by CI)
|
||||||
|
3) default: None (will use file-backed sqlite in temp dir)
|
||||||
|
"""
|
||||||
|
parser.addoption(
|
||||||
|
"--db-url",
|
||||||
|
action="store",
|
||||||
|
default=os.environ.get("ASSETS_TEST_DB_URL"),
|
||||||
|
help="SQLAlchemy DB URL (e.g. sqlite:///path/to/db.sqlite3)",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _free_port() -> int:
|
||||||
|
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
||||||
|
s.bind(("127.0.0.1", 0))
|
||||||
|
return s.getsockname()[1]
|
||||||
|
|
||||||
|
|
||||||
|
def _make_base_dirs(root: Path) -> None:
|
||||||
|
for sub in ("models", "custom_nodes", "input", "output", "temp", "user"):
|
||||||
|
(root / sub).mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None:
|
||||||
|
start = time.time()
|
||||||
|
last_err = None
|
||||||
|
while time.time() - start < timeout:
|
||||||
|
try:
|
||||||
|
r = session.get(base + "/api/assets", timeout=5)
|
||||||
|
if r.status_code in (200, 400):
|
||||||
|
return
|
||||||
|
except Exception as e:
|
||||||
|
last_err = e
|
||||||
|
time.sleep(0.25)
|
||||||
|
raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def comfy_tmp_base_dir() -> Path:
|
||||||
|
env_base = os.environ.get("ASSETS_TEST_BASE_DIR")
|
||||||
|
created_by_fixture = False
|
||||||
|
if env_base:
|
||||||
|
tmp = Path(env_base)
|
||||||
|
tmp.mkdir(parents=True, exist_ok=True)
|
||||||
|
else:
|
||||||
|
tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-"))
|
||||||
|
created_by_fixture = True
|
||||||
|
_make_base_dirs(tmp)
|
||||||
|
yield tmp
|
||||||
|
if created_by_fixture:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
for p in sorted(tmp.rglob("*"), reverse=True):
|
||||||
|
if p.is_file() or p.is_symlink():
|
||||||
|
p.unlink(missing_ok=True)
|
||||||
|
for p in sorted(tmp.glob("**/*"), reverse=True):
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
p.rmdir()
|
||||||
|
tmp.rmdir()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(scope="session")
|
||||||
|
def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest):
|
||||||
|
"""
|
||||||
|
Boot ComfyUI subprocess with:
|
||||||
|
- sandbox base dir
|
||||||
|
- file-backed sqlite DB in temp dir
|
||||||
|
- autoscan disabled
|
||||||
|
Returns (base_url, process, port)
|
||||||
|
"""
|
||||||
|
port = _free_port()
|
||||||
|
db_url = request.config.getoption("--db-url")
|
||||||
|
if not db_url:
|
||||||
|
# Use a file-backed sqlite database in the temp directory
|
||||||
|
db_path = comfy_tmp_base_dir / "assets-test.sqlite3"
|
||||||
|
db_url = f"sqlite:///{db_path}"
|
||||||
|
|
||||||
|
logs_dir = comfy_tmp_base_dir / "logs"
|
||||||
|
logs_dir.mkdir(exist_ok=True)
|
||||||
|
out_log = open(logs_dir / "stdout.log", "w", buffering=1)
|
||||||
|
err_log = open(logs_dir / "stderr.log", "w", buffering=1)
|
||||||
|
|
||||||
|
comfy_root = Path(__file__).resolve().parent.parent.parent
|
||||||
|
if not (comfy_root / "main.py").is_file():
|
||||||
|
raise FileNotFoundError(f"main.py not found under {comfy_root}")
|
||||||
|
|
||||||
|
proc = subprocess.Popen(
|
||||||
|
args=[
|
||||||
|
sys.executable,
|
||||||
|
"main.py",
|
||||||
|
f"--base-directory={str(comfy_tmp_base_dir)}",
|
||||||
|
f"--database-url={db_url}",
|
||||||
|
"--disable-assets-autoscan",
|
||||||
|
"--listen",
|
||||||
|
"127.0.0.1",
|
||||||
|
"--port",
|
||||||
|
str(port),
|
||||||
|
"--cpu",
|
||||||
|
],
|
||||||
|
stdout=out_log,
|
||||||
|
stderr=err_log,
|
||||||
|
cwd=str(comfy_root),
|
||||||
|
env={**os.environ},
|
||||||
|
)
|
||||||
|
|
||||||
|
for _ in range(50):
|
||||||
|
if proc.poll() is not None:
|
||||||
|
out_log.flush()
|
||||||
|
err_log.flush()
|
||||||
|
raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}")
|
||||||
|
time.sleep(0.1)
|
||||||
|
|
||||||
|
base_url = f"http://127.0.0.1:{port}"
|
||||||
|
try:
|
||||||
|
with requests.Session() as s:
|
||||||
|
_wait_http_ready(base_url, s, timeout=90.0)
|
||||||
|
yield base_url, proc, port
|
||||||
|
except Exception as e:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
proc.terminate()
|
||||||
|
proc.wait(timeout=10)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
out_log.flush()
|
||||||
|
err_log.flush()
|
||||||
|
raise RuntimeError(f"ComfyUI did not become ready: {e}")
|
||||||
|
|
||||||
|
if proc and proc.poll() is None:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
proc.terminate()
|
||||||
|
proc.wait(timeout=15)
|
||||||
|
out_log.close()
|
||||||
|
err_log.close()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def http() -> Iterator[requests.Session]:
|
||||||
|
with requests.Session() as s:
|
||||||
|
s.timeout = 120
|
||||||
|
yield s
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def api_base(comfy_url_and_proc) -> str:
|
||||||
|
base_url, _proc, _port = comfy_url_and_proc
|
||||||
|
return base_url
|
||||||
|
|
||||||
|
|
||||||
|
def _post_multipart_asset(
|
||||||
|
session: requests.Session,
|
||||||
|
base: str,
|
||||||
|
*,
|
||||||
|
name: str,
|
||||||
|
tags: list[str],
|
||||||
|
meta: dict,
|
||||||
|
data: bytes,
|
||||||
|
extra_fields: Optional[dict] = None,
|
||||||
|
) -> tuple[int, dict]:
|
||||||
|
files = {"file": (name, data, "application/octet-stream")}
|
||||||
|
form_data = {
|
||||||
|
"tags": json.dumps(tags),
|
||||||
|
"name": name,
|
||||||
|
"user_metadata": json.dumps(meta),
|
||||||
|
}
|
||||||
|
if extra_fields:
|
||||||
|
for k, v in extra_fields.items():
|
||||||
|
form_data[k] = v
|
||||||
|
r = session.post(base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||||
|
return r.status_code, r.json()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def make_asset_bytes() -> Callable[[str, int], bytes]:
|
||||||
|
def _make(name: str, size: int = 8192) -> bytes:
|
||||||
|
seed = sum(ord(c) for c in name) % 251
|
||||||
|
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||||
|
return _make
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def asset_factory(http: requests.Session, api_base: str):
|
||||||
|
"""
|
||||||
|
Returns create(name, tags, meta, data) -> response dict
|
||||||
|
Tracks created ids and deletes them after the test.
|
||||||
|
"""
|
||||||
|
created: list[str] = []
|
||||||
|
|
||||||
|
def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict:
|
||||||
|
status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data)
|
||||||
|
assert status in (200, 201), body
|
||||||
|
created.append(body["id"])
|
||||||
|
return body
|
||||||
|
|
||||||
|
yield create
|
||||||
|
|
||||||
|
for aid in created:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict:
|
||||||
|
"""
|
||||||
|
Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/<name>.
|
||||||
|
Returns response dict with id, asset_hash, tags, etc.
|
||||||
|
"""
|
||||||
|
name = "unit_1_example.safetensors"
|
||||||
|
p = getattr(request, "param", {}) or {}
|
||||||
|
tags: Optional[list[str]] = p.get("tags")
|
||||||
|
if tags is None:
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||||
|
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||||
|
files = {"file": (name, b"A" * 4096, "application/octet-stream")}
|
||||||
|
form_data = {
|
||||||
|
"tags": json.dumps(tags),
|
||||||
|
"name": name,
|
||||||
|
"user_metadata": json.dumps(meta),
|
||||||
|
}
|
||||||
|
r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 201, body
|
||||||
|
return body
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def autoclean_unit_test_assets(http: requests.Session, api_base: str):
|
||||||
|
"""Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test."""
|
||||||
|
yield
|
||||||
|
|
||||||
|
while True:
|
||||||
|
r = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests", "limit": "500", "sort": "name"},
|
||||||
|
timeout=30,
|
||||||
|
)
|
||||||
|
if r.status_code != 200:
|
||||||
|
break
|
||||||
|
body = r.json()
|
||||||
|
ids = [a["id"] for a in body.get("assets", [])]
|
||||||
|
if not ids:
|
||||||
|
break
|
||||||
|
for aid in ids:
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||||
|
|
||||||
|
|
||||||
|
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
|
||||||
|
"""Force a fast sync/seed pass by calling the seed endpoint."""
|
||||||
|
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
|
||||||
|
time.sleep(0.2)
|
||||||
|
|
||||||
|
|
||||||
|
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||||
|
return asset_hash.removeprefix("blake3:") + extension
|
||||||
348
tests-unit/assets_test/test_assets_missing_sync.py
Normal file
348
tests-unit/assets_test/test_assets_missing_sync.py
Normal file
@ -0,0 +1,348 @@
|
|||||||
|
import os
|
||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_seed_asset_removed_when_file_is_deleted(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
):
|
||||||
|
"""Asset without hash (seed) whose file disappears:
|
||||||
|
after triggering sync_seed_assets, Asset + AssetInfo disappear.
|
||||||
|
"""
|
||||||
|
# Create a file directly under input/unit-tests/<case> so tags include "unit-tests"
|
||||||
|
case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed"
|
||||||
|
case_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
name = f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||||
|
fp = case_dir / name
|
||||||
|
fp.write_bytes(b"Z" * 2048)
|
||||||
|
|
||||||
|
# Trigger a seed sync so DB sees this path (seed asset => hash is NULL)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
# Verify it is visible via API and carries no hash (seed)
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
# there should be exactly one with that name
|
||||||
|
matches = [a for a in body1.get("assets", []) if a.get("name") == name]
|
||||||
|
assert matches
|
||||||
|
assert matches[0].get("asset_hash") is None
|
||||||
|
asset_info_id = matches[0]["id"]
|
||||||
|
|
||||||
|
# Remove the underlying file and sync again
|
||||||
|
if fp.exists():
|
||||||
|
fp.unlink()
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
# It should disappear (AssetInfo and seed Asset gone)
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,syncseed", "name_contains": name},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
matches2 = [a for a in body2.get("assets", []) if a.get("name") == name]
|
||||||
|
assert not matches2, f"Seed asset {asset_info_id} should be gone after sync"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags")
|
||||||
|
def test_hashed_asset_missing_tag_added_then_removed_after_scan(
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""Hashed asset with a single cache_state:
|
||||||
|
1. delete its file -> sync adds 'missing'
|
||||||
|
2. restore file -> sync removes 'missing'
|
||||||
|
"""
|
||||||
|
name = "missing_tag_test.png"
|
||||||
|
tags = ["input", "unit-tests", "msync2"]
|
||||||
|
data = make_asset_bytes(name, 4096)
|
||||||
|
a = asset_factory(name, tags, {}, data)
|
||||||
|
|
||||||
|
# Compute its on-disk path and remove it
|
||||||
|
dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png")
|
||||||
|
assert dest.exists(), f"Expected asset file at {dest}"
|
||||||
|
dest.unlink()
|
||||||
|
|
||||||
|
# Fast sync should add 'missing' to the AssetInfo
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
|
||||||
|
d1 = g1.json()
|
||||||
|
assert g1.status_code == 200, d1
|
||||||
|
assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion"
|
||||||
|
|
||||||
|
# Restore the file with the exact same content and sync again
|
||||||
|
dest.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
dest.write_bytes(data)
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
|
||||||
|
d2 = g2.json()
|
||||||
|
assert g2.status_code == 200, d2
|
||||||
|
assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify"
|
||||||
|
|
||||||
|
|
||||||
|
def test_hashed_asset_two_asset_infos_both_get_missing(
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
):
|
||||||
|
"""Hashed asset with a single cache_state, but two AssetInfo rows:
|
||||||
|
deleting the single file then syncing should add 'missing' to both infos.
|
||||||
|
"""
|
||||||
|
# Upload one hashed asset
|
||||||
|
name = "two_infos_one_path.png"
|
||||||
|
base_tags = ["input", "unit-tests", "multiinfo"]
|
||||||
|
created = asset_factory(name, base_tags, {}, b"A" * 2048)
|
||||||
|
|
||||||
|
# Create second AssetInfo for the same Asset via from-hash
|
||||||
|
payload = {
|
||||||
|
"hash": created["asset_hash"],
|
||||||
|
"name": "two_infos_one_path_copy.png",
|
||||||
|
"tags": base_tags, # keep it in our unit-tests scope for cleanup
|
||||||
|
"user_metadata": {"k": "v"},
|
||||||
|
}
|
||||||
|
r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 201, b2
|
||||||
|
second_id = b2["id"]
|
||||||
|
|
||||||
|
# Remove the single underlying file
|
||||||
|
p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png")
|
||||||
|
assert p.exists()
|
||||||
|
p.unlink()
|
||||||
|
|
||||||
|
r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||||
|
tags0 = r0.json()
|
||||||
|
assert r0.status_code == 200, tags0
|
||||||
|
byname0 = {t["name"]: t for t in tags0.get("tags", [])}
|
||||||
|
old_missing = int(byname0.get("missing", {}).get("count", 0))
|
||||||
|
|
||||||
|
# Sync -> both AssetInfos for this asset must receive 'missing'
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||||
|
da = ga.json()
|
||||||
|
assert ga.status_code == 200, da
|
||||||
|
assert "missing" in set(da.get("tags", []))
|
||||||
|
|
||||||
|
gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120)
|
||||||
|
db = gb.json()
|
||||||
|
assert gb.status_code == 200, db
|
||||||
|
assert "missing" in set(db.get("tags", []))
|
||||||
|
|
||||||
|
# Tag usage for 'missing' increased by exactly 2 (two AssetInfos)
|
||||||
|
r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||||
|
tags1 = r1.json()
|
||||||
|
assert r1.status_code == 200, tags1
|
||||||
|
byname1 = {t["name"]: t for t in tags1.get("tags", [])}
|
||||||
|
new_missing = int(byname1.get("missing", {}).get("count", 0))
|
||||||
|
assert new_missing == old_missing + 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||||
|
def test_hashed_asset_two_cache_states_partial_delete_then_full_delete(
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
run_scan_and_wait,
|
||||||
|
):
|
||||||
|
"""Hashed asset with two cache_state rows:
|
||||||
|
1. delete one file -> sync should NOT add 'missing'
|
||||||
|
2. delete second file -> sync should add 'missing'
|
||||||
|
"""
|
||||||
|
name = "two_cache_states_partial_delete.png"
|
||||||
|
tags = ["input", "unit-tests", "dual"]
|
||||||
|
data = make_asset_bytes(name, 3072)
|
||||||
|
|
||||||
|
created = asset_factory(name, tags, {}, data)
|
||||||
|
path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png")
|
||||||
|
assert path1.exists()
|
||||||
|
|
||||||
|
# Create a second on-disk copy under the same root but different subfolder
|
||||||
|
path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name
|
||||||
|
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path2.write_bytes(data)
|
||||||
|
|
||||||
|
# Fast seed so the second path appears (as a seed initially)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
# Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner.
|
||||||
|
run_scan_and_wait("input")
|
||||||
|
|
||||||
|
# Remove only one file and sync -> asset should still be healthy (no 'missing')
|
||||||
|
path1.unlink()
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||||
|
d1 = g1.json()
|
||||||
|
assert g1.status_code == 200, d1
|
||||||
|
assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains"
|
||||||
|
|
||||||
|
# Baseline 'missing' usage count just before last file removal
|
||||||
|
r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||||
|
tags0 = r0.json()
|
||||||
|
assert r0.status_code == 200, tags0
|
||||||
|
old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||||
|
|
||||||
|
# Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo
|
||||||
|
path2.unlink()
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120)
|
||||||
|
d2 = g2.json()
|
||||||
|
assert g2.status_code == 200, d2
|
||||||
|
assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain"
|
||||||
|
|
||||||
|
# Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset)
|
||||||
|
r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120)
|
||||||
|
tags1 = r1.json()
|
||||||
|
assert r1.status_code == 200, tags1
|
||||||
|
new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0))
|
||||||
|
assert new_missing == old_missing + 2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Fast pass alone clears 'missing' when size and mtime match exactly:
|
||||||
|
1) upload (hashed), record original mtime_ns
|
||||||
|
2) delete -> fast pass adds 'missing'
|
||||||
|
3) restore same bytes and set mtime back to the original value
|
||||||
|
4) run fast pass again -> 'missing' is removed (no slow scan)
|
||||||
|
"""
|
||||||
|
scope = f"fastclear-{uuid.uuid4().hex[:6]}"
|
||||||
|
name = "fastpass_clear.bin"
|
||||||
|
data = make_asset_bytes(name, 3072)
|
||||||
|
|
||||||
|
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||||
|
aid = a["id"]
|
||||||
|
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||||
|
p = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||||
|
st0 = p.stat()
|
||||||
|
orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000))
|
||||||
|
|
||||||
|
# Delete -> fast pass adds 'missing'
|
||||||
|
p.unlink()
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d1 = g1.json()
|
||||||
|
assert g1.status_code == 200, d1
|
||||||
|
assert "missing" in set(d1.get("tags", []))
|
||||||
|
|
||||||
|
# Restore same bytes and revert mtime to the original value
|
||||||
|
p.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p.write_bytes(data)
|
||||||
|
# set both atime and mtime in ns to ensure exact match
|
||||||
|
os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns))
|
||||||
|
|
||||||
|
# Fast pass should clear 'missing' without a scan
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d2 = g2.json()
|
||||||
|
assert g2.status_code == 200, d2
|
||||||
|
assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_fastpass_removes_stale_state_row_no_missing(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
run_scan_and_wait,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hashed asset with two states:
|
||||||
|
- delete one file
|
||||||
|
- run fast pass only
|
||||||
|
Expect:
|
||||||
|
- asset stays healthy (no 'missing')
|
||||||
|
- stale AssetCacheState row for the deleted path is removed.
|
||||||
|
We verify this behaviorally by recreating the deleted path and running fast pass again:
|
||||||
|
a new *seed* AssetInfo is created, which proves the old state row was not reused.
|
||||||
|
"""
|
||||||
|
scope = f"stale-{uuid.uuid4().hex[:6]}"
|
||||||
|
name = "two_states.bin"
|
||||||
|
data = make_asset_bytes(name, 2048)
|
||||||
|
|
||||||
|
# Upload hashed asset at path1
|
||||||
|
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||||
|
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||||
|
a1_filename = get_asset_filename(a["asset_hash"], ".bin")
|
||||||
|
p1 = base / a1_filename
|
||||||
|
assert p1.exists()
|
||||||
|
|
||||||
|
aid = a["id"]
|
||||||
|
h = a["asset_hash"]
|
||||||
|
|
||||||
|
# Create second state path2, seed+scan to dedupe into the same Asset
|
||||||
|
p2 = base / "copy" / name
|
||||||
|
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p2.write_bytes(data)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
run_scan_and_wait(root)
|
||||||
|
|
||||||
|
# Delete path1 and run fast pass -> no 'missing' and stale state row should be removed
|
||||||
|
p1.unlink()
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d1 = g1.json()
|
||||||
|
assert g1.status_code == 200, d1
|
||||||
|
assert "missing" not in set(d1.get("tags", []))
|
||||||
|
|
||||||
|
# Recreate path1 and run fast pass again.
|
||||||
|
# If the stale state row was removed, a NEW seed AssetInfo will appear for this path.
|
||||||
|
p1.write_bytes(data)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
rl = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": f"unit-tests,{scope}"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
bl = rl.json()
|
||||||
|
assert rl.status_code == 200, bl
|
||||||
|
items = bl.get("assets", [])
|
||||||
|
# one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null)
|
||||||
|
hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)]
|
||||||
|
assert h in hashes
|
||||||
|
assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path"
|
||||||
|
|
||||||
|
# Asset identity still healthy
|
||||||
|
rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
|
||||||
|
assert rh.status_code == 200
|
||||||
306
tests-unit/assets_test/test_crud.py
Normal file
306
tests-unit/assets_test/test_crud.py
Normal file
@ -0,0 +1,306 @@
|
|||||||
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_from_hash_success(
|
||||||
|
http: requests.Session, api_base: str, seeded_asset: dict
|
||||||
|
):
|
||||||
|
h = seeded_asset["asset_hash"]
|
||||||
|
payload = {
|
||||||
|
"hash": h,
|
||||||
|
"name": "from_hash_ok.safetensors",
|
||||||
|
"tags": ["models", "checkpoints", "unit-tests", "from-hash"],
|
||||||
|
"user_metadata": {"k": "v"},
|
||||||
|
}
|
||||||
|
r1 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 201, b1
|
||||||
|
assert b1["asset_hash"] == h
|
||||||
|
assert b1["created_new"] is False
|
||||||
|
aid = b1["id"]
|
||||||
|
|
||||||
|
# Calling again with the same name should return the same AssetInfo id
|
||||||
|
r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 201, b2
|
||||||
|
assert b2["id"] == aid
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
|
||||||
|
# GET detail
|
||||||
|
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
detail = rg.json()
|
||||||
|
assert rg.status_code == 200, detail
|
||||||
|
assert detail["id"] == aid
|
||||||
|
assert "user_metadata" in detail
|
||||||
|
assert "filename" in detail["user_metadata"]
|
||||||
|
|
||||||
|
# DELETE
|
||||||
|
rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
assert rd.status_code == 204
|
||||||
|
|
||||||
|
# GET again -> 404
|
||||||
|
rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
body = rg2.json()
|
||||||
|
assert rg2.status_code == 404
|
||||||
|
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_upon_reference_count(
|
||||||
|
http: requests.Session, api_base: str, seeded_asset: dict
|
||||||
|
):
|
||||||
|
# Create a second reference to the same asset via from-hash
|
||||||
|
src_hash = seeded_asset["asset_hash"]
|
||||||
|
payload = {
|
||||||
|
"hash": src_hash,
|
||||||
|
"name": "unit_ref_copy.safetensors",
|
||||||
|
"tags": ["models", "checkpoints", "unit-tests", "del-flow"],
|
||||||
|
"user_metadata": {"note": "copy"},
|
||||||
|
}
|
||||||
|
r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||||
|
copy = r2.json()
|
||||||
|
assert r2.status_code == 201, copy
|
||||||
|
assert copy["asset_hash"] == src_hash
|
||||||
|
assert copy["created_new"] is False
|
||||||
|
|
||||||
|
# Delete original reference -> asset identity must remain
|
||||||
|
aid1 = seeded_asset["id"]
|
||||||
|
rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120)
|
||||||
|
assert rd1.status_code == 204
|
||||||
|
|
||||||
|
rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||||
|
assert rh1.status_code == 200 # identity still present
|
||||||
|
|
||||||
|
# Delete the last reference with default semantics -> identity and cached files removed
|
||||||
|
aid2 = copy["id"]
|
||||||
|
rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120)
|
||||||
|
assert rd2.status_code == 204
|
||||||
|
|
||||||
|
rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||||
|
assert rh2.status_code == 404 # orphan content removed
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
original_tags = seeded_asset["tags"]
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"name": "unit_1_renamed.safetensors",
|
||||||
|
"user_metadata": {"purpose": "updated", "epoch": 2},
|
||||||
|
}
|
||||||
|
ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120)
|
||||||
|
body = ru.json()
|
||||||
|
assert ru.status_code == 200, body
|
||||||
|
assert body["name"] == payload["name"]
|
||||||
|
assert body["tags"] == original_tags # tags unchanged
|
||||||
|
assert body["user_metadata"]["purpose"] == "updated"
|
||||||
|
# filename should still be present and normalized by server
|
||||||
|
assert "filename" in body["user_metadata"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
h = seeded_asset["asset_hash"]
|
||||||
|
|
||||||
|
# Existing
|
||||||
|
rh1 = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120)
|
||||||
|
assert rh1.status_code == 200
|
||||||
|
|
||||||
|
# Non-existent
|
||||||
|
rh2 = http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}", timeout=120)
|
||||||
|
assert rh2.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str):
|
||||||
|
# Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload.
|
||||||
|
# requests exposes an empty body for HEAD, so validate status and that there is no payload.
|
||||||
|
rh = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120)
|
||||||
|
assert rh.status_code == 400
|
||||||
|
body = rh.content
|
||||||
|
assert body == b""
|
||||||
|
|
||||||
|
|
||||||
|
def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str):
|
||||||
|
bogus = str(uuid.uuid4())
|
||||||
|
r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 404
|
||||||
|
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_from_hash_invalids(http: requests.Session, api_base: str):
|
||||||
|
# Bad hash algorithm
|
||||||
|
bad = {
|
||||||
|
"hash": "sha256:" + "0" * 64,
|
||||||
|
"name": "x.bin",
|
||||||
|
"tags": ["models", "checkpoints", "unit-tests"],
|
||||||
|
}
|
||||||
|
r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 400
|
||||||
|
assert b1["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
# Invalid JSON body
|
||||||
|
r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 400
|
||||||
|
assert b2["error"]["code"] == "INVALID_JSON"
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_update_download_bad_ids(http: requests.Session, api_base: str):
|
||||||
|
# All endpoints should be not found, as we UUID regex directly in the route definition.
|
||||||
|
bad_id = "not-a-uuid"
|
||||||
|
|
||||||
|
r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120)
|
||||||
|
assert r1.status_code == 404
|
||||||
|
|
||||||
|
r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120)
|
||||||
|
assert r3.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
r = http.put(f"{api_base}/api/assets/{aid}", json={}, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_concurrent_delete_same_asset_info_single_204(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Many concurrent DELETE for the same AssetInfo should result in:
|
||||||
|
- exactly one 204 No Content (the one that actually deleted)
|
||||||
|
- all others 404 Not Found (row already gone)
|
||||||
|
"""
|
||||||
|
scope = f"conc-del-{uuid.uuid4().hex[:6]}"
|
||||||
|
name = "to_delete.bin"
|
||||||
|
data = make_asset_bytes(name, 1536)
|
||||||
|
|
||||||
|
created = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||||
|
aid = created["id"]
|
||||||
|
|
||||||
|
# Hit the same endpoint N times in parallel.
|
||||||
|
n_tests = 4
|
||||||
|
url = f"{api_base}/api/assets/{aid}?delete_content=false"
|
||||||
|
|
||||||
|
def _do_delete(delete_url):
|
||||||
|
with requests.Session() as s:
|
||||||
|
return s.delete(delete_url, timeout=120).status_code
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=n_tests) as ex:
|
||||||
|
statuses = list(ex.map(_do_delete, [url] * n_tests))
|
||||||
|
|
||||||
|
# Exactly one actual delete, the rest must be 404
|
||||||
|
assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}"
|
||||||
|
assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}"
|
||||||
|
|
||||||
|
# The resource must be gone.
|
||||||
|
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
assert rg.status_code == 404
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_metadata_filename_is_set_for_seed_asset_without_hash(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
):
|
||||||
|
"""Seed ingest (no hash yet) must compute user_metadata['filename'] immediately."""
|
||||||
|
scope = f"seedmeta-{uuid.uuid4().hex[:6]}"
|
||||||
|
name = "seed_filename.bin"
|
||||||
|
|
||||||
|
base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b"
|
||||||
|
base.mkdir(parents=True, exist_ok=True)
|
||||||
|
fp = base / name
|
||||||
|
fp.write_bytes(b"Z" * 2048)
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": f"unit-tests,{scope}", "name_contains": name},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body = r1.json()
|
||||||
|
assert r1.status_code == 200, body
|
||||||
|
matches = [a for a in body.get("assets", []) if a.get("name") == name]
|
||||||
|
assert matches, "Seed asset should be visible after sync"
|
||||||
|
assert matches[0].get("asset_hash") is None # still a seed
|
||||||
|
aid = matches[0]["id"]
|
||||||
|
|
||||||
|
r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
detail = r2.json()
|
||||||
|
assert r2.status_code == 200, detail
|
||||||
|
filename = (detail.get("user_metadata") or {}).get("filename")
|
||||||
|
expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/")
|
||||||
|
assert filename == expected, f"expected filename={expected}, got {filename!r}"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states")
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_metadata_filename_computed_and_updated_on_retarget(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
run_scan_and_wait,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
1) Ingest under {root}/unit-tests/<scope>/a/b/<name> -> filename reflects relative path.
|
||||||
|
2) Retarget by copying to {root}/unit-tests/<scope>/x/<new_name>, remove old file,
|
||||||
|
run fast pass + scan -> filename updates to new relative path.
|
||||||
|
"""
|
||||||
|
scope = f"meta-fn-{uuid.uuid4().hex[:6]}"
|
||||||
|
name1 = "compute_metadata_filename.png"
|
||||||
|
name2 = "compute_changed_metadata_filename.png"
|
||||||
|
data = make_asset_bytes(name1, 2100)
|
||||||
|
|
||||||
|
# Upload into nested path a/b
|
||||||
|
a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data)
|
||||||
|
aid = a["id"]
|
||||||
|
|
||||||
|
root_base = comfy_tmp_base_dir / root
|
||||||
|
p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png"))
|
||||||
|
assert p1.exists()
|
||||||
|
|
||||||
|
# filename at ingest should be the path relative to root
|
||||||
|
rel1 = str(p1.relative_to(root_base)).replace("\\", "/")
|
||||||
|
g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d1 = g1.json()
|
||||||
|
assert g1.status_code == 200, d1
|
||||||
|
fn1 = d1["user_metadata"].get("filename")
|
||||||
|
assert fn1 == rel1
|
||||||
|
|
||||||
|
# Retarget: copy to x/<name2>, remove old, then sync+scan
|
||||||
|
p2 = root_base / "unit-tests" / scope / "x" / name2
|
||||||
|
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p2.write_bytes(data)
|
||||||
|
if p1.exists():
|
||||||
|
p1.unlink()
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base) # seed the new path
|
||||||
|
run_scan_and_wait(root) # verify/hash and reconcile
|
||||||
|
|
||||||
|
# filename should now point at x/<name2>
|
||||||
|
rel2 = str(p2.relative_to(root_base)).replace("\\", "/")
|
||||||
|
g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d2 = g2.json()
|
||||||
|
assert g2.status_code == 200, d2
|
||||||
|
fn2 = d2["user_metadata"].get("filename")
|
||||||
|
assert fn2 == rel2
|
||||||
166
tests-unit/assets_test/test_downloads.py
Normal file
166
tests-unit/assets_test/test_downloads.py
Normal file
@ -0,0 +1,166 @@
|
|||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
from datetime import datetime
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||||
|
|
||||||
|
|
||||||
|
def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
|
||||||
|
# default attachment
|
||||||
|
r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||||
|
data = r1.content
|
||||||
|
assert r1.status_code == 200
|
||||||
|
cd = r1.headers.get("Content-Disposition", "")
|
||||||
|
assert "attachment" in cd
|
||||||
|
assert data and len(data) == 4096
|
||||||
|
|
||||||
|
# inline requested
|
||||||
|
r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120)
|
||||||
|
r2.content
|
||||||
|
assert r2.status_code == 200
|
||||||
|
cd2 = r2.headers.get("Content-Disposition", "")
|
||||||
|
assert "inline" in cd2
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_download_chooses_existing_state_and_updates_access_time(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
run_scan_and_wait,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Hashed asset with two state paths: if the first one disappears,
|
||||||
|
GET /content still serves from the remaining path and bumps last_access_time.
|
||||||
|
"""
|
||||||
|
scope = f"dl-first-{uuid.uuid4().hex[:6]}"
|
||||||
|
name = "first_existing_state.bin"
|
||||||
|
data = make_asset_bytes(name, 3072)
|
||||||
|
|
||||||
|
# Upload -> path1
|
||||||
|
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||||
|
aid = a["id"]
|
||||||
|
|
||||||
|
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||||
|
path1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||||
|
assert path1.exists()
|
||||||
|
|
||||||
|
# Seed path2 by copying, then scan to dedupe into a second state
|
||||||
|
path2 = base / "alt" / name
|
||||||
|
path2.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path2.write_bytes(data)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
run_scan_and_wait(root)
|
||||||
|
|
||||||
|
# Remove path1 so server must fall back to path2
|
||||||
|
path1.unlink()
|
||||||
|
|
||||||
|
# last_access_time before
|
||||||
|
rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d0 = rg0.json()
|
||||||
|
assert rg0.status_code == 200, d0
|
||||||
|
ts0 = d0.get("last_access_time")
|
||||||
|
|
||||||
|
time.sleep(0.05)
|
||||||
|
r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||||
|
blob = r.content
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert blob == data # must serve from the surviving state (same bytes)
|
||||||
|
|
||||||
|
rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
d1 = rg1.json()
|
||||||
|
assert rg1.status_code == 200, d1
|
||||||
|
ts1 = d1.get("last_access_time")
|
||||||
|
|
||||||
|
def _parse_iso8601(s: Optional[str]) -> Optional[float]:
|
||||||
|
if not s:
|
||||||
|
return None
|
||||||
|
s = s[:-1] if s.endswith("Z") else s
|
||||||
|
return datetime.fromisoformat(s).timestamp()
|
||||||
|
|
||||||
|
t0 = _parse_iso8601(ts0)
|
||||||
|
t1 = _parse_iso8601(ts1)
|
||||||
|
assert t1 is not None
|
||||||
|
if t0 is not None:
|
||||||
|
assert t1 > t0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True)
|
||||||
|
def test_download_missing_file_returns_404(
|
||||||
|
http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict
|
||||||
|
):
|
||||||
|
# Remove the underlying file then attempt download.
|
||||||
|
# We initialize fixture without additional tags to know exactly the asset file path.
|
||||||
|
try:
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
detail = rg.json()
|
||||||
|
assert rg.status_code == 200
|
||||||
|
asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors")
|
||||||
|
abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename
|
||||||
|
assert abs_path.exists()
|
||||||
|
abs_path.unlink()
|
||||||
|
|
||||||
|
r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||||
|
assert r2.status_code == 404
|
||||||
|
body = r2.json()
|
||||||
|
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||||
|
finally:
|
||||||
|
# We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually.
|
||||||
|
dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
dr.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states")
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_download_404_if_all_states_missing(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
run_scan_and_wait,
|
||||||
|
):
|
||||||
|
"""Multi-state asset: after the last remaining on-disk file is removed, download must return 404."""
|
||||||
|
scope = f"dl-404-{uuid.uuid4().hex[:6]}"
|
||||||
|
name = "missing_all_states.bin"
|
||||||
|
data = make_asset_bytes(name, 2048)
|
||||||
|
|
||||||
|
# Upload -> path1
|
||||||
|
a = asset_factory(name, [root, "unit-tests", scope], {}, data)
|
||||||
|
aid = a["id"]
|
||||||
|
|
||||||
|
base = comfy_tmp_base_dir / root / "unit-tests" / scope
|
||||||
|
p1 = base / get_asset_filename(a["asset_hash"], ".bin")
|
||||||
|
assert p1.exists()
|
||||||
|
|
||||||
|
# Seed a second state and dedupe
|
||||||
|
p2 = base / "copy" / name
|
||||||
|
p2.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
p2.write_bytes(data)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
run_scan_and_wait(root)
|
||||||
|
|
||||||
|
# Remove first file -> download should still work via the second state
|
||||||
|
p1.unlink()
|
||||||
|
ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||||
|
b1 = ok1.content
|
||||||
|
assert ok1.status_code == 200 and b1 == data
|
||||||
|
|
||||||
|
# Remove the last file -> download must 404
|
||||||
|
p2.unlink()
|
||||||
|
r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120)
|
||||||
|
body = r2.json()
|
||||||
|
assert r2.status_code == 404
|
||||||
|
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||||
342
tests-unit/assets_test/test_list_filter.py
Normal file
342
tests-unit/assets_test/test_list_filter.py
Normal file
@ -0,0 +1,342 @@
|
|||||||
|
import time
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||||
|
names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"]
|
||||||
|
for n in names:
|
||||||
|
asset_factory(
|
||||||
|
n,
|
||||||
|
["models", "checkpoints", "unit-tests", "paging"],
|
||||||
|
{"epoch": 1},
|
||||||
|
make_asset_bytes(n, size=2048),
|
||||||
|
)
|
||||||
|
|
||||||
|
# name ascending for stable order
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
got1 = [a["name"] for a in b1["assets"]]
|
||||||
|
assert got1 == sorted(names)[:2]
|
||||||
|
assert b1["has_more"] is True
|
||||||
|
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
got2 = [a["name"] for a in b2["assets"]]
|
||||||
|
assert got2 == sorted(names)[2:]
|
||||||
|
assert b2["has_more"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory):
|
||||||
|
a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024)
|
||||||
|
b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024)
|
||||||
|
|
||||||
|
r = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 200
|
||||||
|
names = [x["name"] for x in body["assets"]]
|
||||||
|
assert a["name"] in names
|
||||||
|
assert b["name"] not in names
|
||||||
|
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests", "name_contains": "inc_"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
names2 = [x["name"] for x in body2["assets"]]
|
||||||
|
assert a["name"] in names2
|
||||||
|
assert b["name"] in names2
|
||||||
|
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "non-existing-tag"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body3 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert not body3["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-size"]
|
||||||
|
n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors"
|
||||||
|
asset_factory(n1, t, {}, make_asset_bytes(n1, 1024))
|
||||||
|
asset_factory(n2, t, {}, make_asset_bytes(n2, 2048))
|
||||||
|
asset_factory(n3, t, {}, make_asset_bytes(n3, 3072))
|
||||||
|
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
names = [a["name"] for a in b1["assets"]]
|
||||||
|
assert names[:3] == [n1, n2, n3]
|
||||||
|
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
names2 = [a["name"] for a in b2["assets"]]
|
||||||
|
assert names2[:3] == [n3, n2, n1]
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-upd"]
|
||||||
|
a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200))
|
||||||
|
a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200))
|
||||||
|
|
||||||
|
# Rename the second asset to bump updated_at
|
||||||
|
rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120)
|
||||||
|
upd = rp.json()
|
||||||
|
assert rp.status_code == 200, upd
|
||||||
|
|
||||||
|
r = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 200
|
||||||
|
names = [x["name"] for x in body["assets"]]
|
||||||
|
assert names[0] == "upd_b_renamed.safetensors"
|
||||||
|
assert a1["name"] in names
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-access"]
|
||||||
|
asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100))
|
||||||
|
time.sleep(0.02)
|
||||||
|
a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100))
|
||||||
|
|
||||||
|
# Touch last_access_time of b by downloading its content
|
||||||
|
time.sleep(0.02)
|
||||||
|
dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120)
|
||||||
|
assert dl.status_code == 200
|
||||||
|
dl.content
|
||||||
|
|
||||||
|
r = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 200
|
||||||
|
names = [x["name"] for x in body["assets"]]
|
||||||
|
assert names[0] == a2["name"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-include"]
|
||||||
|
a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva"))
|
||||||
|
asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb"))
|
||||||
|
|
||||||
|
# CSV + case-insensitive
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names1 = [x["name"] for x in b1["assets"]]
|
||||||
|
assert a["name"] in names1
|
||||||
|
assert not any("beta" in x for x in names1)
|
||||||
|
|
||||||
|
# Repeated query params for include_tags
|
||||||
|
params_multi = [
|
||||||
|
("include_tags", "unit-tests"),
|
||||||
|
("include_tags", "lf-include"),
|
||||||
|
("include_tags", "alpha"),
|
||||||
|
]
|
||||||
|
r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
names2 = [x["name"] for x in b2["assets"]]
|
||||||
|
assert a["name"] in names2
|
||||||
|
assert not any("beta" in x for x in names2)
|
||||||
|
|
||||||
|
# Duplicates and spaces in CSV
|
||||||
|
r3 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": " unit-tests , lf-include , alpha , alpha "},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 200
|
||||||
|
names3 = [x["name"] for x in b3["assets"]]
|
||||||
|
assert a["name"] in names3
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-exclude"]
|
||||||
|
a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900))
|
||||||
|
asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900))
|
||||||
|
|
||||||
|
# Exclude uppercase should work
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names1 = [x["name"] for x in b1["assets"]]
|
||||||
|
assert a["name"] in names1
|
||||||
|
# Repeated excludes with duplicates
|
||||||
|
params_multi = [
|
||||||
|
("include_tags", "unit-tests"),
|
||||||
|
("include_tags", "lf-exclude"),
|
||||||
|
("exclude_tags", "beta"),
|
||||||
|
("exclude_tags", "beta"),
|
||||||
|
]
|
||||||
|
r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
names2 = [x["name"] for x in b2["assets"]]
|
||||||
|
assert all("beta" not in x for x in names2)
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-name"]
|
||||||
|
a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800))
|
||||||
|
a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800))
|
||||||
|
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names1 = [x["name"] for x in b1["assets"]]
|
||||||
|
assert a1["name"] in names1
|
||||||
|
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
names2 = [x["name"] for x in b2["assets"]]
|
||||||
|
assert a1["name"] in names2
|
||||||
|
|
||||||
|
r3 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 200
|
||||||
|
names3 = [x["name"] for x in b3["assets"]]
|
||||||
|
assert a2["name"] in names3
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"]
|
||||||
|
asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600))
|
||||||
|
asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600))
|
||||||
|
asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600))
|
||||||
|
|
||||||
|
# Offset far beyond total
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
assert not b1["assets"]
|
||||||
|
assert b1["has_more"] is False
|
||||||
|
|
||||||
|
# Boundary large limit (<=500 is valid)
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert len(b2["assets"]) == 3
|
||||||
|
assert b2["has_more"] is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base):
|
||||||
|
r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 400
|
||||||
|
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 400
|
||||||
|
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str):
|
||||||
|
# limit too small
|
||||||
|
r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 400
|
||||||
|
assert b1["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
# bad metadata JSON
|
||||||
|
r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 400
|
||||||
|
assert b2["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_list_assets_name_contains_literal_underscore(
|
||||||
|
http,
|
||||||
|
api_base,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""'name_contains' must treat '_' literally, not as a SQL wildcard.
|
||||||
|
We create:
|
||||||
|
- foo_bar.safetensors (should match)
|
||||||
|
- fooxbar.safetensors (must NOT match if '_' is escaped)
|
||||||
|
- foobar.safetensors (must NOT match)
|
||||||
|
"""
|
||||||
|
scope = f"lf-underscore-{uuid.uuid4().hex[:6]}"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", scope]
|
||||||
|
|
||||||
|
a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700))
|
||||||
|
b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700))
|
||||||
|
c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700))
|
||||||
|
|
||||||
|
r = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 200, body
|
||||||
|
names = [x["name"] for x in body["assets"]]
|
||||||
|
assert a["name"] in names, f"Expected literal underscore match to include {a['name']}"
|
||||||
|
assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'"
|
||||||
|
assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'"
|
||||||
|
assert body["total"] == 1
|
||||||
395
tests-unit/assets_test/test_metadata_filters.py
Normal file
395
tests-unit/assets_test/test_metadata_filters.py
Normal file
@ -0,0 +1,395 @@
|
|||||||
|
import json
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_and_across_keys_and_types(
|
||||||
|
http, api_base: str, asset_factory, make_asset_bytes
|
||||||
|
):
|
||||||
|
name = "mf_and_mix.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-and"]
|
||||||
|
meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||||
|
asset_factory(name, tags, meta, make_asset_bytes(name, 4096))
|
||||||
|
|
||||||
|
# All keys must match (AND semantics)
|
||||||
|
f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23}
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-and",
|
||||||
|
"metadata_filter": json.dumps(f_ok),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names = [a["name"] for a in b1["assets"]]
|
||||||
|
assert name in names
|
||||||
|
|
||||||
|
# One key mismatched -> no result
|
||||||
|
f_bad = {"purpose": "mix", "epoch": 2, "active": True}
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-and",
|
||||||
|
"metadata_filter": json.dumps(f_bad),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert not b2["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
name = "mf_types.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-types"]
|
||||||
|
meta = {"epoch": 1, "active": True}
|
||||||
|
asset_factory(name, tags, meta, make_asset_bytes(name))
|
||||||
|
|
||||||
|
# int filter matches numeric
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-types",
|
||||||
|
"metadata_filter": json.dumps({"epoch": 1}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||||
|
|
||||||
|
# string "1" must NOT match numeric 1
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-types",
|
||||||
|
"metadata_filter": json.dumps({"epoch": "1"}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and not b2["assets"]
|
||||||
|
|
||||||
|
# bool True matches, string "true" must NOT match
|
||||||
|
r3 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-types",
|
||||||
|
"metadata_filter": json.dumps({"active": True}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])
|
||||||
|
|
||||||
|
r4 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-types",
|
||||||
|
"metadata_filter": json.dumps({"active": "true"}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b4 = r4.json()
|
||||||
|
assert r4.status_code == 200 and not b4["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
name = "mf_list_scalars.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-list"]
|
||||||
|
meta = {"flags": ["red", "green"]}
|
||||||
|
asset_factory(name, tags, meta, make_asset_bytes(name, 3000))
|
||||||
|
|
||||||
|
# Any-of should match because "green" is present
|
||||||
|
filt_ok = {"flags": ["blue", "green"]}
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||||
|
|
||||||
|
# None of provided flags present -> no match
|
||||||
|
filt_miss = {"flags": ["blue", "yellow"]}
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and not b2["assets"]
|
||||||
|
|
||||||
|
# Duplicates in list should not break matching
|
||||||
|
filt_dup = {"flags": ["green", "green", "green"]}
|
||||||
|
r3 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_none_semantics_missing_or_null_and_any_of_with_none(
|
||||||
|
http, api_base, asset_factory, make_asset_bytes
|
||||||
|
):
|
||||||
|
# a1: key missing; a2: explicit null; a3: concrete value
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "mf-none"]
|
||||||
|
a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1"))
|
||||||
|
a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2"))
|
||||||
|
a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3"))
|
||||||
|
|
||||||
|
# Filter {maybe: None} must match a1 and a2, not a3
|
||||||
|
filt = {"maybe": None}
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
got = [a["name"] for a in b1["assets"]]
|
||||||
|
assert a1["name"] in got and a2["name"] in got and a3["name"] not in got
|
||||||
|
|
||||||
|
# Any-of with None should include missing/null plus value matches
|
||||||
|
filt_any = {"maybe": [None, "x"]}
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
got2 = [a["name"] for a in b2["assets"]]
|
||||||
|
assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
name = "mf_nested_json.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-nested"]
|
||||||
|
cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}}
|
||||||
|
asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200))
|
||||||
|
|
||||||
|
# Exact JSON object equality (same structure)
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-nested",
|
||||||
|
"metadata_filter": json.dumps({"config": cfg}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||||
|
|
||||||
|
# Different JSON object should not match
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-nested",
|
||||||
|
"metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and not b2["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
name = "mf_list_objects.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-objlist"]
|
||||||
|
transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}]
|
||||||
|
asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048))
|
||||||
|
|
||||||
|
# Any-of for list of objects should match when one element equals the filter object
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-objlist",
|
||||||
|
"metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||||
|
|
||||||
|
# Non-matching object -> no match
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={
|
||||||
|
"include_tags": "unit-tests,mf-objlist",
|
||||||
|
"metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}),
|
||||||
|
},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and not b2["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
name = "mf_keys_unicode.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-keys"]
|
||||||
|
meta = {
|
||||||
|
"weird.key": "v1",
|
||||||
|
"path/like": 7,
|
||||||
|
"with:colon": True,
|
||||||
|
"ключ": "значение",
|
||||||
|
"emoji": "🐍",
|
||||||
|
}
|
||||||
|
asset_factory(name, tags, meta, make_asset_bytes(name, 1500))
|
||||||
|
|
||||||
|
# Match all the special keys
|
||||||
|
filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"}
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||||
|
|
||||||
|
# Unicode key match
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"]
|
||||||
|
a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025))
|
||||||
|
a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026))
|
||||||
|
|
||||||
|
# count == 0 must match only a0
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names1 = [a["name"] for a in b1["assets"]]
|
||||||
|
assert a0["name"] in names1 and a1["name"] not in names1
|
||||||
|
|
||||||
|
# Any-of list of booleans: True matches second asset
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"])
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
name = "mf_mixed_list.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "mf-mixed"]
|
||||||
|
meta = {"mix": ["1", 1, True, None]}
|
||||||
|
asset_factory(name, tags, meta, make_asset_bytes(name, 1999))
|
||||||
|
|
||||||
|
# Should match because 1 is present
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"])
|
||||||
|
|
||||||
|
# Should NOT match for False
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and not b2["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
# Use a unique scope tag to avoid interference
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"]
|
||||||
|
x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua"))
|
||||||
|
y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub"))
|
||||||
|
|
||||||
|
# Filtering by unknown key with None should return both (missing key OR null)
|
||||||
|
r1 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names = {a["name"] for a in b1["assets"]}
|
||||||
|
assert x["name"] in names and y["name"] in names
|
||||||
|
|
||||||
|
# Filtering by unknown key with concrete value should return none
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200 and not b2["assets"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
# alpha matches epoch=1; beta has epoch=2
|
||||||
|
a = asset_factory(
|
||||||
|
"mf_tag_alpha.safetensors",
|
||||||
|
["models", "checkpoints", "unit-tests", "mf-tag", "alpha"],
|
||||||
|
{"epoch": 1},
|
||||||
|
make_asset_bytes("alpha"),
|
||||||
|
)
|
||||||
|
b = asset_factory(
|
||||||
|
"mf_tag_beta.safetensors",
|
||||||
|
["models", "checkpoints", "unit-tests", "mf-tag", "beta"],
|
||||||
|
{"epoch": 2},
|
||||||
|
make_asset_bytes("beta"),
|
||||||
|
)
|
||||||
|
|
||||||
|
params = {
|
||||||
|
"include_tags": "unit-tests,mf-tag,alpha",
|
||||||
|
"exclude_tags": "beta",
|
||||||
|
"name_contains": "mf_tag_",
|
||||||
|
"metadata_filter": json.dumps({"epoch": 1}),
|
||||||
|
}
|
||||||
|
r = http.get(api_base + "/api/assets", params=params, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 200
|
||||||
|
names = [x["name"] for x in body["assets"]]
|
||||||
|
assert a["name"] in names
|
||||||
|
assert b["name"] not in names
|
||||||
|
|
||||||
|
|
||||||
|
def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes):
|
||||||
|
# Three assets in same scope with different sizes and a common filter key
|
||||||
|
t = ["models", "checkpoints", "unit-tests", "mf-sort"]
|
||||||
|
n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors"
|
||||||
|
asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024))
|
||||||
|
asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048))
|
||||||
|
asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072))
|
||||||
|
|
||||||
|
# Sort by size ascending with paging
|
||||||
|
q = {
|
||||||
|
"include_tags": "unit-tests,mf-sort",
|
||||||
|
"metadata_filter": json.dumps({"group": "g"}),
|
||||||
|
"sort": "size", "order": "asc", "limit": "2",
|
||||||
|
}
|
||||||
|
r1 = http.get(api_base + "/api/assets", params=q, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
got1 = [a["name"] for a in b1["assets"]]
|
||||||
|
assert got1 == [n1, n2]
|
||||||
|
assert b1["has_more"] is True
|
||||||
|
|
||||||
|
q2 = {**q, "offset": "2"}
|
||||||
|
r2 = http.get(api_base + "/api/assets", params=q2, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
got2 = [a["name"] for a in b2["assets"]]
|
||||||
|
assert got2 == [n3]
|
||||||
|
assert b2["has_more"] is False
|
||||||
141
tests-unit/assets_test/test_prune_orphaned_assets.py
Normal file
141
tests-unit/assets_test/test_prune_orphaned_assets.py
Normal file
@ -0,0 +1,141 @@
|
|||||||
|
import uuid
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
import requests
|
||||||
|
from conftest import get_asset_filename, trigger_sync_seed_assets
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def create_seed_file(comfy_tmp_base_dir: Path):
|
||||||
|
"""Create a file on disk that will become a seed asset after sync."""
|
||||||
|
created: list[Path] = []
|
||||||
|
|
||||||
|
def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path:
|
||||||
|
name = name or f"seed_{uuid.uuid4().hex[:8]}.bin"
|
||||||
|
path = comfy_tmp_base_dir / root / "unit-tests" / scope / name
|
||||||
|
path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
path.write_bytes(data)
|
||||||
|
created.append(path)
|
||||||
|
return path
|
||||||
|
|
||||||
|
yield _create
|
||||||
|
|
||||||
|
for p in created:
|
||||||
|
p.unlink(missing_ok=True)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def find_asset(http: requests.Session, api_base: str):
|
||||||
|
"""Query API for assets matching scope and optional name."""
|
||||||
|
def _find(scope: str, name: str | None = None) -> list[dict]:
|
||||||
|
params = {"include_tags": f"unit-tests,{scope}"}
|
||||||
|
if name:
|
||||||
|
params["name_contains"] = name
|
||||||
|
r = http.get(f"{api_base}/api/assets", params=params, timeout=120)
|
||||||
|
assert r.status_code == 200
|
||||||
|
assets = r.json().get("assets", [])
|
||||||
|
if name:
|
||||||
|
return [a for a in assets if a.get("name") == name]
|
||||||
|
return assets
|
||||||
|
|
||||||
|
return _find
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_orphaned_seed_asset_is_pruned(
|
||||||
|
root: str,
|
||||||
|
create_seed_file,
|
||||||
|
find_asset,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
):
|
||||||
|
"""Seed asset with deleted file is removed; with file present, it survives."""
|
||||||
|
scope = f"prune-{uuid.uuid4().hex[:6]}"
|
||||||
|
fp = create_seed_file(root, scope)
|
||||||
|
name = fp.name
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
assert find_asset(scope, name), "Seed asset should exist"
|
||||||
|
|
||||||
|
fp.unlink()
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
assert not find_asset(scope, name), "Orphaned seed should be pruned"
|
||||||
|
|
||||||
|
|
||||||
|
def test_seed_asset_with_file_survives_prune(
|
||||||
|
create_seed_file,
|
||||||
|
find_asset,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
):
|
||||||
|
"""Seed asset with file still on disk is NOT pruned."""
|
||||||
|
scope = f"keep-{uuid.uuid4().hex[:6]}"
|
||||||
|
fp = create_seed_file("input", scope)
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
assert find_asset(scope, fp.name), "Seed with valid file should survive"
|
||||||
|
|
||||||
|
|
||||||
|
def test_hashed_asset_not_pruned_when_file_missing(
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""Hashed assets are never deleted by prune, even without file."""
|
||||||
|
scope = f"hashed-{uuid.uuid4().hex[:6]}"
|
||||||
|
data = make_asset_bytes("test", 2048)
|
||||||
|
a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data)
|
||||||
|
|
||||||
|
path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin")
|
||||||
|
path.unlink()
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120)
|
||||||
|
assert r.status_code == 200, "Hashed asset should NOT be pruned"
|
||||||
|
|
||||||
|
|
||||||
|
def test_prune_across_multiple_roots(
|
||||||
|
create_seed_file,
|
||||||
|
find_asset,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
):
|
||||||
|
"""Prune correctly handles assets across input and output roots."""
|
||||||
|
scope = f"multi-{uuid.uuid4().hex[:6]}"
|
||||||
|
input_fp = create_seed_file("input", scope, "input.bin")
|
||||||
|
create_seed_file("output", scope, "output.bin")
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
assert len(find_asset(scope)) == 2
|
||||||
|
|
||||||
|
input_fp.unlink()
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
remaining = find_asset(scope)
|
||||||
|
assert len(remaining) == 1
|
||||||
|
assert remaining[0]["name"] == "output.bin"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"])
|
||||||
|
def test_special_chars_in_path_escaped_correctly(
|
||||||
|
dirname: str,
|
||||||
|
create_seed_file,
|
||||||
|
find_asset,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
comfy_tmp_base_dir: Path,
|
||||||
|
):
|
||||||
|
"""SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches."""
|
||||||
|
scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}"
|
||||||
|
fp = create_seed_file("input", scope)
|
||||||
|
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
trigger_sync_seed_assets(http, api_base)
|
||||||
|
|
||||||
|
assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive"
|
||||||
225
tests-unit/assets_test/test_tags.py
Normal file
225
tests-unit/assets_test/test_tags.py
Normal file
@ -0,0 +1,225 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
|
||||||
|
def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
# Include zero-usage tags by default
|
||||||
|
r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120)
|
||||||
|
body1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names = [t["name"] for t in body1["tags"]]
|
||||||
|
# A few system tags from migration should exist:
|
||||||
|
assert "models" in names
|
||||||
|
assert "checkpoints" in names
|
||||||
|
|
||||||
|
# Only used tags before we add anything new from this test cycle
|
||||||
|
r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120)
|
||||||
|
body2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
# We already seeded one asset via fixture, so used tags must be non-empty
|
||||||
|
used_names = [t["name"] for t in body2["tags"]]
|
||||||
|
assert "models" in used_names
|
||||||
|
assert "checkpoints" in used_names
|
||||||
|
|
||||||
|
# Prefix filter should refine the list
|
||||||
|
r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 200
|
||||||
|
names3 = [t["name"] for t in b3["tags"]]
|
||||||
|
assert "unit-tests" in names3
|
||||||
|
assert "models" not in names3 # filtered out by prefix
|
||||||
|
|
||||||
|
# Order by name ascending should be stable
|
||||||
|
r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120)
|
||||||
|
b4 = r4.json()
|
||||||
|
assert r4.status_code == 200
|
||||||
|
names4 = [t["name"] for t in b4["tags"]]
|
||||||
|
assert names4 == sorted(names4)
|
||||||
|
|
||||||
|
|
||||||
|
def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||||
|
# Baseline: system tags exist when include_zero (default) is true
|
||||||
|
r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120)
|
||||||
|
body1 = r1.json()
|
||||||
|
assert r1.status_code == 200
|
||||||
|
names = [t["name"] for t in body1["tags"]]
|
||||||
|
assert "models" in names and "checkpoints" in names
|
||||||
|
|
||||||
|
# Create a short-lived asset under input with a unique custom tag
|
||||||
|
scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}"
|
||||||
|
custom_tag = f"temp-{uuid.uuid4().hex[:8]}"
|
||||||
|
name = "tag_seed.bin"
|
||||||
|
_asset = asset_factory(
|
||||||
|
name,
|
||||||
|
["input", "unit-tests", scope, custom_tag],
|
||||||
|
{},
|
||||||
|
make_asset_bytes(name, 512),
|
||||||
|
)
|
||||||
|
|
||||||
|
# While the asset exists, the custom tag must appear when include_zero=false
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/tags",
|
||||||
|
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
used_names = [t["name"] for t in body2["tags"]]
|
||||||
|
assert custom_tag in used_names
|
||||||
|
|
||||||
|
# Delete the asset so the tag usage drops to zero
|
||||||
|
rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120)
|
||||||
|
assert rd.status_code == 204
|
||||||
|
|
||||||
|
# Now the custom tag must not be returned when include_zero=false
|
||||||
|
r3 = http.get(
|
||||||
|
api_base + "/api/tags",
|
||||||
|
params={"include_zero": "false", "prefix": custom_tag, "limit": "50"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
body3 = r3.json()
|
||||||
|
assert r3.status_code == 200
|
||||||
|
names_after = [t["name"] for t in body3["tags"]]
|
||||||
|
assert custom_tag not in names_after
|
||||||
|
assert not names_after # filtered view should be empty now
|
||||||
|
|
||||||
|
|
||||||
|
def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
|
||||||
|
# Add tags with duplicates and mixed case
|
||||||
|
payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]}
|
||||||
|
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200, b1
|
||||||
|
# normalized, deduplicated; 'unit-tests' was already present from the seed
|
||||||
|
assert set(b1["added"]) == {"newtag", "beta"}
|
||||||
|
assert set(b1["already_present"]) == {"unit-tests"}
|
||||||
|
assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"]
|
||||||
|
|
||||||
|
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
g = rg.json()
|
||||||
|
assert rg.status_code == 200
|
||||||
|
tags_now = set(g["tags"])
|
||||||
|
assert {"newtag", "beta"}.issubset(tags_now)
|
||||||
|
|
||||||
|
# Remove a tag and a non-existent tag
|
||||||
|
payload_del = {"tags": ["newtag", "does-not-exist"]}
|
||||||
|
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert set(b2["removed"]) == {"newtag"}
|
||||||
|
assert set(b2["not_present"]) == {"does-not-exist"}
|
||||||
|
|
||||||
|
# Verify remaining tags after deletion
|
||||||
|
rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
g2 = rg2.json()
|
||||||
|
assert rg2.status_code == 200
|
||||||
|
tags_later = set(g2["tags"])
|
||||||
|
assert "newtag" not in tags_later
|
||||||
|
assert "beta" in tags_later # still present
|
||||||
|
|
||||||
|
|
||||||
|
def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
h = seeded_asset["asset_hash"]
|
||||||
|
|
||||||
|
# Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1)
|
||||||
|
r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120)
|
||||||
|
add_body = r_add.json()
|
||||||
|
assert r_add.status_code == 200, add_body
|
||||||
|
|
||||||
|
# Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'.
|
||||||
|
payload = {
|
||||||
|
"hash": h,
|
||||||
|
"name": "order_only_bbb.safetensors",
|
||||||
|
"tags": ["input", "unit-tests", "orderbbb"],
|
||||||
|
"user_metadata": {},
|
||||||
|
}
|
||||||
|
r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||||
|
copy_body = r_copy.json()
|
||||||
|
assert r_copy.status_code == 201, copy_body
|
||||||
|
|
||||||
|
# 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa'
|
||||||
|
# because it has higher usage (2 vs 1).
|
||||||
|
r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 200, b1
|
||||||
|
names1 = [t["name"] for t in b1["tags"]]
|
||||||
|
counts1 = {t["name"]: t["count"] for t in b1["tags"]}
|
||||||
|
# Both must be present within the prefix subset
|
||||||
|
assert "orderaaa" in names1 and "orderbbb" in names1
|
||||||
|
# Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1
|
||||||
|
assert counts1["orderbbb"] >= counts1["orderaaa"]
|
||||||
|
# And with count_desc, 'orderbbb' appears earlier than 'orderaaa'
|
||||||
|
assert names1.index("orderbbb") < names1.index("orderaaa")
|
||||||
|
|
||||||
|
# 2) name_asc: lexical order should flip the relative order
|
||||||
|
r2 = http.get(
|
||||||
|
api_base + "/api/tags",
|
||||||
|
params={"prefix": "order", "include_zero": "false", "order": "name_asc"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200, b2
|
||||||
|
names2 = [t["name"] for t in b2["tags"]]
|
||||||
|
assert "orderaaa" in names2 and "orderbbb" in names2
|
||||||
|
assert names2.index("orderaaa") < names2.index("orderbbb")
|
||||||
|
|
||||||
|
# 3) invalid limit rejected (existing negative case retained)
|
||||||
|
r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 400
|
||||||
|
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||||
|
aid = seeded_asset["id"]
|
||||||
|
|
||||||
|
# Add with empty list
|
||||||
|
r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 400
|
||||||
|
assert b1["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
# Remove with wrong type
|
||||||
|
r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 400
|
||||||
|
assert b2["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
# metadata_filter provided as JSON array should be rejected (must be object)
|
||||||
|
r3 = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"metadata_filter": json.dumps([{"x": 1}])},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
b3 = r3.json()
|
||||||
|
assert r3.status_code == 400
|
||||||
|
assert b3["error"]["code"] == "INVALID_QUERY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_tags_prefix_treats_underscore_literal(
|
||||||
|
http,
|
||||||
|
api_base,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""'prefix' for /api/tags must treat '_' literally, not as a wildcard."""
|
||||||
|
base = f"pref_{uuid.uuid4().hex[:6]}"
|
||||||
|
tag_ok = f"{base}_ok" # should match prefix=f"{base}_"
|
||||||
|
tag_bad = f"{base}xok" # must NOT match if '_' is escaped
|
||||||
|
scope = f"tags-underscore-{uuid.uuid4().hex[:6]}"
|
||||||
|
|
||||||
|
asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512))
|
||||||
|
asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512))
|
||||||
|
|
||||||
|
r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 200, body
|
||||||
|
names = [t["name"] for t in body["tags"]]
|
||||||
|
assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'"
|
||||||
|
assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard"
|
||||||
|
assert body["total"] == 1
|
||||||
281
tests-unit/assets_test/test_uploads.py
Normal file
281
tests-unit/assets_test/test_uploads.py
Normal file
@ -0,0 +1,281 @@
|
|||||||
|
import json
|
||||||
|
import uuid
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
|
|
||||||
|
import requests
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes):
|
||||||
|
name = "dup_a.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||||
|
meta = {"purpose": "dup"}
|
||||||
|
data = make_asset_bytes(name)
|
||||||
|
files = {"file": (name, data, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||||
|
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
a1 = r1.json()
|
||||||
|
assert r1.status_code == 201, a1
|
||||||
|
assert a1["created_new"] is True
|
||||||
|
|
||||||
|
# Second upload with the same data and name should return created_new == False and the same asset
|
||||||
|
files = {"file": (name, data, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||||
|
r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
a2 = r2.json()
|
||||||
|
assert r2.status_code == 200, a2
|
||||||
|
assert a2["created_new"] is False
|
||||||
|
assert a2["asset_hash"] == a1["asset_hash"]
|
||||||
|
assert a2["id"] == a1["id"] # old reference
|
||||||
|
|
||||||
|
# Third upload with the same data but new name should return created_new == False and the new AssetReference
|
||||||
|
files = {"file": (name, data, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)}
|
||||||
|
r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
a3 = r2.json()
|
||||||
|
assert r2.status_code == 200, a3
|
||||||
|
assert a3["created_new"] is False
|
||||||
|
assert a3["asset_hash"] == a1["asset_hash"]
|
||||||
|
assert a3["id"] != a1["id"] # old reference
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str):
|
||||||
|
# Seed a small file first
|
||||||
|
name = "fastpath_seed.safetensors"
|
||||||
|
tags = ["models", "checkpoints", "unit-tests"]
|
||||||
|
meta = {}
|
||||||
|
files = {"file": (name, b"B" * 1024, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)}
|
||||||
|
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 201, b1
|
||||||
|
h = b1["asset_hash"]
|
||||||
|
|
||||||
|
# Now POST /api/assets with only hash and no file
|
||||||
|
files = [
|
||||||
|
("hash", (None, h)),
|
||||||
|
("tags", (None, json.dumps(tags))),
|
||||||
|
("name", (None, "fastpath_copy.safetensors")),
|
||||||
|
("user_metadata", (None, json.dumps({"purpose": "copy"}))),
|
||||||
|
]
|
||||||
|
r2 = http.post(api_base + "/api/assets", files=files, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False
|
||||||
|
assert b2["created_new"] is False
|
||||||
|
assert b2["asset_hash"] == h
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_fastpath_with_known_hash_and_file(
|
||||||
|
http: requests.Session, api_base: str
|
||||||
|
):
|
||||||
|
# Seed
|
||||||
|
files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})}
|
||||||
|
r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
b1 = r1.json()
|
||||||
|
assert r1.status_code == 201, b1
|
||||||
|
h = b1["asset_hash"]
|
||||||
|
|
||||||
|
# Send both file and hash of existing content -> server must drain file and create from hash (200)
|
||||||
|
files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")}
|
||||||
|
form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})}
|
||||||
|
r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
b2 = r2.json()
|
||||||
|
assert r2.status_code == 200, b2
|
||||||
|
assert b2["created_new"] is False
|
||||||
|
assert b2["asset_hash"] == h
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str):
|
||||||
|
data = [
|
||||||
|
("tags", "models,checkpoints"),
|
||||||
|
("tags", json.dumps(["unit-tests", "alpha"])),
|
||||||
|
("name", "merge.safetensors"),
|
||||||
|
("user_metadata", json.dumps({"u": 1})),
|
||||||
|
]
|
||||||
|
files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")}
|
||||||
|
r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120)
|
||||||
|
created = r1.json()
|
||||||
|
assert r1.status_code in (200, 201), created
|
||||||
|
aid = created["id"]
|
||||||
|
|
||||||
|
# Verify all tags are present on the resource
|
||||||
|
rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||||
|
detail = rg.json()
|
||||||
|
assert rg.status_code == 200, detail
|
||||||
|
tags = set(detail["tags"])
|
||||||
|
assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_concurrent_upload_identical_bytes_different_names(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Two concurrent uploads of identical bytes but different names.
|
||||||
|
Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True.
|
||||||
|
"""
|
||||||
|
scope = f"concupload-{uuid.uuid4().hex[:6]}"
|
||||||
|
name1, name2 = "cu_a.bin", "cu_b.bin"
|
||||||
|
data = make_asset_bytes("concurrent", 4096)
|
||||||
|
tags = [root, "unit-tests", scope]
|
||||||
|
|
||||||
|
def _do_upload(args):
|
||||||
|
url, form_data, files_data = args
|
||||||
|
with requests.Session() as s:
|
||||||
|
return s.post(url, data=form_data, files=files_data, timeout=120)
|
||||||
|
|
||||||
|
url = api_base + "/api/assets"
|
||||||
|
form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})}
|
||||||
|
files1 = {"file": (name1, data, "application/octet-stream")}
|
||||||
|
form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})}
|
||||||
|
files2 = {"file": (name2, data, "application/octet-stream")}
|
||||||
|
|
||||||
|
with ThreadPoolExecutor(max_workers=2) as executor:
|
||||||
|
futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)]))
|
||||||
|
r1, r2 = futures
|
||||||
|
|
||||||
|
b1, b2 = r1.json(), r2.json()
|
||||||
|
assert r1.status_code in (200, 201), b1
|
||||||
|
assert r2.status_code in (200, 201), b2
|
||||||
|
assert b1["asset_hash"] == b2["asset_hash"]
|
||||||
|
assert b1["id"] != b2["id"]
|
||||||
|
|
||||||
|
created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))])
|
||||||
|
assert created_flags == [False, True]
|
||||||
|
|
||||||
|
rl = http.get(
|
||||||
|
api_base + "/api/assets",
|
||||||
|
params={"include_tags": f"unit-tests,{scope}", "sort": "name"},
|
||||||
|
timeout=120,
|
||||||
|
)
|
||||||
|
bl = rl.json()
|
||||||
|
assert rl.status_code == 200, bl
|
||||||
|
names = [a["name"] for a in bl.get("assets", [])]
|
||||||
|
assert set([name1, name2]).issubset(names)
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str):
|
||||||
|
payload = {
|
||||||
|
"hash": "blake3:" + "0" * 64,
|
||||||
|
"name": "nonexistent.bin",
|
||||||
|
"tags": ["models", "checkpoints", "unit-tests"],
|
||||||
|
}
|
||||||
|
r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 404
|
||||||
|
assert body["error"]["code"] == "ASSET_NOT_FOUND"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_zero_byte_rejected(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("empty.safetensors", b"", "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "EMPTY_UPLOAD"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_requires_multipart(http: requests.Session, api_base: str):
|
||||||
|
r = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 415
|
||||||
|
assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_missing_file_and_hash(http: requests.Session, api_base: str):
|
||||||
|
files = [
|
||||||
|
("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))),
|
||||||
|
("name", (None, "x.safetensors")),
|
||||||
|
]
|
||||||
|
r = http.post(api_base + "/api/assets", files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "MISSING_FILE"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_models_unknown_category(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
assert body["error"]["message"].startswith("unknown models category")
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_models_requires_category(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
|
def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||||
|
root: str,
|
||||||
|
http: requests.Session,
|
||||||
|
api_base: str,
|
||||||
|
asset_factory,
|
||||||
|
make_asset_bytes,
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Two uploads use the same tags and the same display name but different bytes.
|
||||||
|
With hash-based filenames, they must NOT overwrite each other. Both assets
|
||||||
|
remain accessible and serve their original content.
|
||||||
|
"""
|
||||||
|
scope = f"dup-path-{uuid.uuid4().hex[:6]}"
|
||||||
|
display_name = "same_display.bin"
|
||||||
|
|
||||||
|
d1 = make_asset_bytes(scope + "-v1", 1536)
|
||||||
|
d2 = make_asset_bytes(scope + "-v2", 2048)
|
||||||
|
tags = [root, "unit-tests", scope]
|
||||||
|
|
||||||
|
first = asset_factory(display_name, tags, {}, d1)
|
||||||
|
second = asset_factory(display_name, tags, {}, d2)
|
||||||
|
|
||||||
|
assert first["id"] != second["id"]
|
||||||
|
assert first["asset_hash"] != second["asset_hash"] # different content
|
||||||
|
assert first["name"] == second["name"] == display_name
|
||||||
|
|
||||||
|
# Both must be independently retrievable
|
||||||
|
r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120)
|
||||||
|
b1 = r1.content
|
||||||
|
assert r1.status_code == 200
|
||||||
|
assert b1 == d1
|
||||||
|
r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120)
|
||||||
|
b2 = r2.content
|
||||||
|
assert r2.status_code == 200
|
||||||
|
assert b2 == d2
|
||||||
@ -2,3 +2,4 @@ pytest>=7.8.0
|
|||||||
pytest-aiohttp
|
pytest-aiohttp
|
||||||
pytest-asyncio
|
pytest-asyncio
|
||||||
websocket-client
|
websocket-client
|
||||||
|
blake3
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user