From 6ea8c128a3770763f150e97aca4b29bd2aec60ad Mon Sep 17 00:00:00 2001 From: Jedrzej Kosinski Date: Fri, 30 Jan 2026 23:22:05 -0800 Subject: [PATCH] Assets Part 2 - add more endpoints (#12125) --- app/assets/api/routes.py | 414 +++++++++- app/assets/api/schemas_in.py | 196 ++++- app/assets/api/schemas_out.py | 33 + app/assets/database/queries.py | 719 +++++++++++++++++- app/assets/helpers.py | 97 ++- app/assets/manager.py | 401 +++++++++- app/assets/scanner.py | 36 +- tests-unit/assets_test/conftest.py | 271 +++++++ .../assets_test/test_assets_missing_sync.py | 348 +++++++++ tests-unit/assets_test/test_crud.py | 306 ++++++++ tests-unit/assets_test/test_downloads.py | 166 ++++ tests-unit/assets_test/test_list_filter.py | 342 +++++++++ .../assets_test/test_metadata_filters.py | 395 ++++++++++ .../assets_test/test_prune_orphaned_assets.py | 141 ++++ tests-unit/assets_test/test_tags.py | 225 ++++++ tests-unit/assets_test/test_uploads.py | 281 +++++++ tests-unit/requirements.txt | 1 + 17 files changed, 4347 insertions(+), 25 deletions(-) create mode 100644 tests-unit/assets_test/conftest.py create mode 100644 tests-unit/assets_test/test_assets_missing_sync.py create mode 100644 tests-unit/assets_test/test_crud.py create mode 100644 tests-unit/assets_test/test_downloads.py create mode 100644 tests-unit/assets_test/test_list_filter.py create mode 100644 tests-unit/assets_test/test_metadata_filters.py create mode 100644 tests-unit/assets_test/test_prune_orphaned_assets.py create mode 100644 tests-unit/assets_test/test_tags.py create mode 100644 tests-unit/assets_test/test_uploads.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 30e87a898..7676e50b4 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -1,5 +1,8 @@ import logging import uuid +import urllib.parse +import os +import contextlib from aiohttp import web from pydantic import ValidationError @@ -8,6 +11,9 @@ import app.assets.manager as manager from app import user_manager from app.assets.api import schemas_in from app.assets.helpers import get_query_dict +from app.assets.scanner import seed_assets + +import folder_paths ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None @@ -15,6 +21,9 @@ USER_MANAGER: user_manager.UserManager | None = None # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" +# Note to any custom node developers reading this code: +# The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. + def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: global USER_MANAGER USER_MANAGER = user_manager_instance @@ -28,6 +37,18 @@ def _validation_error_response(code: str, ve: ValidationError) -> web.Response: return _error_response(400, code, "Validation failed.", {"errors": ve.json()}) +@ROUTES.head("/api/assets/hash/{hash}") +async def head_asset_by_hash(request: web.Request) -> web.Response: + hash_str = request.match_info.get("hash", "").strip().lower() + if not hash_str or ":" not in hash_str: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = hash_str.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + exists = manager.asset_exists(asset_hash=hash_str) + return web.Response(status=200 if exists else 404) + + @ROUTES.get("/api/assets") async def list_assets(request: web.Request) -> web.Response: """ @@ -50,7 +71,7 @@ async def list_assets(request: web.Request) -> web.Response: order=q.order, owner_id=USER_MANAGER.get_request_user_id(request), ) - return web.json_response(payload.model_dump(mode="json")) + return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}") @@ -76,6 +97,314 @@ async def get_asset(request: web.Request) -> web.Response: return web.json_response(result.model_dump(mode="json"), status=200) +@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") +async def download_asset_content(request: web.Request) -> web.Response: + # question: do we need disposition? could we just stick with one of these? + disposition = request.query.get("disposition", "attachment").lower().strip() + if disposition not in {"inline", "attachment"}: + disposition = "attachment" + + try: + abs_path, content_type, filename = manager.resolve_asset_content_for_download( + asset_info_id=str(uuid.UUID(request.match_info["id"])), + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve)) + except NotImplementedError as nie: + return _error_response(501, "BACKEND_UNSUPPORTED", str(nie)) + except FileNotFoundError: + return _error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + + quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") + cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + + file_size = os.path.getsize(abs_path) + logging.info( + "download_asset_content: path=%s, size=%d bytes (%.2f MB), content_type=%s, filename=%s", + abs_path, + file_size, + file_size / (1024 * 1024), + content_type, + filename, + ) + + async def file_sender(): + chunk_size = 64 * 1024 + with open(abs_path, "rb") as f: + while True: + chunk = f.read(chunk_size) + if not chunk: + break + yield chunk + + return web.Response( + body=file_sender(), + content_type=content_type, + headers={ + "Content-Disposition": cd, + "Content-Length": str(file_size), + }, + ) + + +@ROUTES.post("/api/assets/from-hash") +async def create_asset_from_hash(request: web.Request) -> web.Response: + try: + payload = await request.json() + body = schemas_in.CreateFromHashBody.model_validate(payload) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + result = manager.create_asset_from_hash( + hash_str=body.hash, + name=body.name, + tags=body.tags, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return web.json_response(result.model_dump(mode="json"), status=201) + + +@ROUTES.post("/api/assets") +async def upload_asset(request: web.Request) -> web.Response: + """Multipart/form-data endpoint for Asset uploads.""" + if not (request.content_type or "").lower().startswith("multipart/"): + return _error_response(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + + reader = await request.multipart() + + file_present = False + file_client_name: str | None = None + tags_raw: list[str] = [] + provided_name: str | None = None + user_metadata_raw: str | None = None + provided_hash: str | None = None + provided_hash_exists: bool | None = None + + file_written = 0 + tmp_path: str | None = None + while True: + field = await reader.next() + if field is None: + break + + fname = getattr(field, "name", "") or "" + + if fname == "hash": + try: + s = ((await field.text()) or "").strip().lower() + except Exception: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + + if s: + if ":" not in s: + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + return _error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + provided_hash = f"{algo}:{digest}" + try: + provided_hash_exists = manager.asset_exists(asset_hash=provided_hash) + except Exception: + provided_hash_exists = None # do not fail the whole request here + + elif fname == "file": + file_present = True + file_client_name = (field.filename or "").strip() + + if provided_hash and provided_hash_exists is True: + # If client supplied a hash that we know exists, drain but do not write to disk + try: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + file_written += len(chunk) + except Exception: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") + continue # Do not create temp file; we will create AssetInfo from the existing content + + # Otherwise, store to temp for hashing/ingest + uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") + unique_dir = os.path.join(uploads_root, uuid.uuid4().hex) + os.makedirs(unique_dir, exist_ok=True) + tmp_path = os.path.join(unique_dir, ".upload.part") + + try: + with open(tmp_path, "wb") as f: + while True: + chunk = await field.read_chunk(8 * 1024 * 1024) + if not chunk: + break + f.write(chunk) + file_written += len(chunk) + except Exception: + try: + if os.path.exists(tmp_path or ""): + os.remove(tmp_path) + finally: + return _error_response(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") + elif fname == "tags": + tags_raw.append((await field.text()) or "") + elif fname == "name": + provided_name = (await field.text()) or None + elif fname == "user_metadata": + user_metadata_raw = (await field.text()) or None + + # If client did not send file, and we are not doing a from-hash fast path -> error + if not file_present and not (provided_hash and provided_hash_exists): + return _error_response(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") + + if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): + # Empty upload is only acceptable if we are fast-pathing from existing hash + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _error_response(400, "EMPTY_UPLOAD", "Uploaded file is empty.") + + try: + spec = schemas_in.UploadAssetSpec.model_validate({ + "tags": tags_raw, + "name": provided_name, + "user_metadata": user_metadata_raw, + "hash": provided_hash, + }) + except ValidationError as ve: + try: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + finally: + return _validation_error_response("INVALID_BODY", ve) + + # Validate models category against configured folders (consistent with previous behavior) + if spec.tags and spec.tags[0] == "models": + if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + return _error_response( + 400, "INVALID_BODY", f"unknown models category '{spec.tags[1] if len(spec.tags) >= 2 else ''}'" + ) + + owner_id = USER_MANAGER.get_request_user_id(request) + + # Fast path: if a valid provided hash exists, create AssetInfo without writing anything + if spec.hash and provided_hash_exists is True: + try: + result = manager.create_asset_from_hash( + hash_str=spec.hash, + name=spec.name or (spec.hash.split(":", 1)[1]), + tags=spec.tags, + user_metadata=spec.user_metadata or {}, + owner_id=owner_id, + ) + except Exception: + logging.exception("create_asset_from_hash failed for hash=%s, owner_id=%s", spec.hash, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if result is None: + return _error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") + + # Drain temp if we accidentally saved (e.g., hash field came after file) + if tmp_path and os.path.exists(tmp_path): + with contextlib.suppress(Exception): + os.remove(tmp_path) + + status = 200 if (not result.created_new) else 201 + return web.json_response(result.model_dump(mode="json"), status=status) + + # Otherwise, we must have a temp file path to ingest + if not tmp_path or not os.path.exists(tmp_path): + # The only case we reach here without a temp file is: client sent a hash that does not exist and no file + return _error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") + + try: + created = manager.upload_asset_from_temp_path( + spec, + temp_path=tmp_path, + client_filename=file_client_name, + owner_id=owner_id, + expected_asset_hash=spec.hash, + ) + status = 201 if created.created_new else 200 + return web.json_response(created.model_dump(mode="json"), status=status) + except ValueError as e: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + msg = str(e) + if "HASH_MISMATCH" in msg or msg.strip().upper() == "HASH_MISMATCH": + return _error_response( + 400, + "HASH_MISMATCH", + "Uploaded file hash does not match provided hash.", + ) + return _error_response(400, "BAD_REQUEST", "Invalid inputs.") + except Exception: + if tmp_path and os.path.exists(tmp_path): + os.remove(tmp_path) + logging.exception("upload_asset_from_temp_path failed for tmp_path=%s, owner_id=%s", tmp_path, owner_id) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + +@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") +async def update_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + body = schemas_in.UpdateAssetBody.model_validate(await request.json()) + except ValidationError as ve: + return _validation_error_response("INVALID_BODY", ve) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.update_asset( + asset_info_id=asset_info_id, + name=body.name, + user_metadata=body.user_metadata, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "update_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") +async def delete_asset(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + delete_content = request.query.get("delete_content") + delete_content = True if delete_content is None else delete_content.lower() not in {"0", "false", "no"} + + try: + deleted = manager.delete_asset_reference( + asset_info_id=asset_info_id, + owner_id=USER_MANAGER.get_request_user_id(request), + delete_content_if_orphan=delete_content, + ) + except Exception: + logging.exception( + "delete_asset_reference failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + if not deleted: + return _error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return web.Response(status=204) + + @ROUTES.get("/api/tags") async def get_tags(request: web.Request) -> web.Response: """ @@ -100,3 +429,86 @@ async def get_tags(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) return web.json_response(result.model_dump(mode="json")) + + +@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def add_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsAdd.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.add_tags_to_asset( + asset_info_id=asset_info_id, + tags=data.tags, + origin="manual", + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except (ValueError, PermissionError) as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") +async def delete_asset_tags(request: web.Request) -> web.Response: + asset_info_id = str(uuid.UUID(request.match_info["id"])) + try: + payload = await request.json() + data = schemas_in.TagsRemove.model_validate(payload) + except ValidationError as ve: + return _error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + except Exception: + return _error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + + try: + result = manager.remove_tags_from_asset( + asset_info_id=asset_info_id, + tags=data.tags, + owner_id=USER_MANAGER.get_request_user_id(request), + ) + except ValueError as ve: + return _error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + except Exception: + logging.exception( + "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", + asset_info_id, + USER_MANAGER.get_request_user_id(request), + ) + return _error_response(500, "INTERNAL", "Unexpected server error.") + + return web.json_response(result.model_dump(mode="json"), status=200) + + +@ROUTES.post("/api/assets/seed") +async def seed_assets_endpoint(request: web.Request) -> web.Response: + """Trigger asset seeding for specified roots (models, input, output).""" + try: + payload = await request.json() + roots = payload.get("roots", ["models", "input", "output"]) + except Exception: + roots = ["models", "input", "output"] + + valid_roots = [r for r in roots if r in ("models", "input", "output")] + if not valid_roots: + return _error_response(400, "INVALID_BODY", "No valid roots specified") + + try: + seed_assets(tuple(valid_roots)) + except Exception: + logging.exception("seed_assets failed for roots=%s", valid_roots) + return _error_response(500, "INTERNAL", "Seed operation failed") + + return web.json_response({"seeded": valid_roots}, status=200) diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 200b41aef..6707ffb0c 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -1,5 +1,4 @@ import json -import uuid from typing import Any, Literal from pydantic import ( @@ -8,9 +7,9 @@ from pydantic import ( Field, conint, field_validator, + model_validator, ) - class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -57,6 +56,57 @@ class ListAssetsQuery(BaseModel): return None +class UpdateAssetBody(BaseModel): + name: str | None = None + user_metadata: dict[str, Any] | None = None + + @model_validator(mode="after") + def _at_least_one(self): + if self.name is None and self.user_metadata is None: + raise ValueError("Provide at least one of: name, user_metadata.") + return self + + +class CreateFromHashBody(BaseModel): + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + hash: str + name: str + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + + @field_validator("hash") + @classmethod + def _require_blake3(cls, v): + s = (v or "").strip().lower() + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return s + + @field_validator("tags", mode="before") + @classmethod + def _tags_norm(cls, v): + if v is None: + return [] + if isinstance(v, list): + out = [str(t).strip().lower() for t in v if str(t).strip()] + seen = set() + dedup = [] + for t in out: + if t not in seen: + seen.add(t) + dedup.append(t) + return dedup + if isinstance(v, str): + return [t.strip().lower() for t in v.split(",") if t.strip()] + return [] + + class TagsListQuery(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) @@ -75,20 +125,140 @@ class TagsListQuery(BaseModel): return v.lower() or None -class SetPreviewBody(BaseModel): - """Set or clear the preview for an AssetInfo. Provide an Asset.id or null.""" - preview_id: str | None = None +class TagsAdd(BaseModel): + model_config = ConfigDict(extra="ignore") + tags: list[str] = Field(..., min_length=1) - @field_validator("preview_id", mode="before") + @field_validator("tags") @classmethod - def _norm_uuid(cls, v): + def normalize_tags(cls, v: list[str]) -> list[str]: + out = [] + for t in v: + if not isinstance(t, str): + raise TypeError("tags must be strings") + tnorm = t.strip().lower() + if tnorm: + out.append(tnorm) + seen = set() + deduplicated = [] + for x in out: + if x not in seen: + seen.add(x) + deduplicated.append(x) + return deduplicated + + +class TagsRemove(TagsAdd): + pass + + +class UploadAssetSpec(BaseModel): + """Upload Asset operation. + - tags: ordered; first is root ('models'|'input'|'output'); + if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths + - name: display name + - user_metadata: arbitrary JSON object (optional) + - hash: optional canonical 'blake3:' provided by the client for validation / fast-path + + Files created via this endpoint are stored on disk using the **content hash** as the filename stem + and the original extension is preserved when available. + """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) + + tags: list[str] = Field(..., min_length=1) + name: str | None = Field(default=None, max_length=512, description="Display Name") + user_metadata: dict[str, Any] = Field(default_factory=dict) + hash: str | None = Field(default=None) + + @field_validator("hash", mode="before") + @classmethod + def _parse_hash(cls, v): if v is None: return None - s = str(v).strip() + s = str(v).strip().lower() if not s: return None - try: - uuid.UUID(s) - except Exception: - raise ValueError("preview_id must be a UUID") - return s + if ":" not in s: + raise ValueError("hash must be 'blake3:'") + algo, digest = s.split(":", 1) + if algo != "blake3": + raise ValueError("only canonical 'blake3:' is accepted here") + if not digest or any(c for c in digest if c not in "0123456789abcdef"): + raise ValueError("hash digest must be lowercase hex") + return f"{algo}:{digest}" + + @field_validator("tags", mode="before") + @classmethod + def _parse_tags(cls, v): + """ + Accepts a list of strings (possibly multiple form fields), + where each string can be: + - JSON array (e.g., '["models","loras","foo"]') + - comma-separated ('models, loras, foo') + - single token ('models') + Returns a normalized, deduplicated, ordered list. + """ + items: list[str] = [] + if v is None: + return [] + if isinstance(v, str): + v = [v] + + if isinstance(v, list): + for item in v: + if item is None: + continue + s = str(item).strip() + if not s: + continue + if s.startswith("["): + try: + arr = json.loads(s) + if isinstance(arr, list): + items.extend(str(x) for x in arr) + continue + except Exception: + pass # fallback to CSV parse below + items.extend([p for p in s.split(",") if p.strip()]) + else: + return [] + + # normalize + dedupe + norm = [] + seen = set() + for t in items: + tnorm = str(t).strip().lower() + if tnorm and tnorm not in seen: + seen.add(tnorm) + norm.append(tnorm) + return norm + + @field_validator("user_metadata", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v or {} + if isinstance(v, str): + s = v.strip() + if not s: + return {} + try: + parsed = json.loads(s) + except Exception as e: + raise ValueError(f"user_metadata must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("user_metadata must be a JSON object") + return parsed + return {} + + @model_validator(mode="after") + def _validate_order(self): + if not self.tags: + raise ValueError("tags must be provided and non-empty") + root = self.tags[0] + if root not in {"models", "input", "output"}: + raise ValueError("first tag must be one of: models, input, output") + if root == "models": + if len(self.tags) < 2: + raise ValueError("models uploads require a category tag as the second tag") + return self diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index 9f8184f20..b6fb3da0c 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -29,6 +29,21 @@ class AssetsList(BaseModel): has_more: bool +class AssetUpdated(BaseModel): + id: str + name: str + asset_hash: str | None = None + tags: list[str] = Field(default_factory=list) + user_metadata: dict[str, Any] = Field(default_factory=dict) + updated_at: datetime | None = None + + model_config = ConfigDict(from_attributes=True) + + @field_serializer("updated_at") + def _ser_updated(self, v: datetime | None, _info): + return v.isoformat() if v else None + + class AssetDetail(BaseModel): id: str name: str @@ -48,6 +63,10 @@ class AssetDetail(BaseModel): return v.isoformat() if v else None +class AssetCreated(AssetDetail): + created_new: bool + + class TagUsage(BaseModel): name: str count: int @@ -58,3 +77,17 @@ class TagsList(BaseModel): tags: list[TagUsage] = Field(default_factory=list) total: int has_more: bool + + +class TagsAdd(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + added: list[str] = Field(default_factory=list) + already_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) + + +class TagsRemove(BaseModel): + model_config = ConfigDict(str_strip_whitespace=True) + removed: list[str] = Field(default_factory=list) + not_present: list[str] = Field(default_factory=list) + total_tags: list[str] = Field(default_factory=list) diff --git a/app/assets/database/queries.py b/app/assets/database/queries.py index 0824c0c2f..d6b33ec7b 100644 --- a/app/assets/database/queries.py +++ b/app/assets/database/queries.py @@ -1,9 +1,17 @@ +import os +import logging import sqlalchemy as sa from collections import defaultdict -from sqlalchemy import select, exists, func +from datetime import datetime +from typing import Iterable, Any +from sqlalchemy import select, delete, exists, func +from sqlalchemy.dialects import sqlite +from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, contains_eager, noload -from app.assets.database.models import Asset, AssetInfo, AssetInfoMeta, AssetInfoTag, Tag -from app.assets.helpers import escape_like_prefix, normalize_tags +from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.helpers import ( + compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow +) from typing import Sequence @@ -15,6 +23,22 @@ def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: return AssetInfo.owner_id.in_(["", owner_id]) +def pick_best_live_path(states: Sequence[AssetCacheState]) -> str: + """ + Return the best on-disk path among cache states: + 1) Prefer a path that exists with needs_verify == False (already verified). + 2) Otherwise, pick the first path that exists. + 3) Otherwise return empty string. + """ + alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + if not alive: + return "" + for s in alive: + if not getattr(s, "needs_verify", False): + return s.file_path + return alive[0].file_path + + def apply_tag_filters( stmt: sa.sql.Select, include_tags: Sequence[str] | None = None, @@ -42,6 +66,7 @@ def apply_tag_filters( ) return stmt + def apply_metadata_filter( stmt: sa.sql.Select, metadata_filter: dict | None = None, @@ -94,7 +119,11 @@ def apply_metadata_filter( return stmt -def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: +def asset_exists_by_hash( + session: Session, + *, + asset_hash: str, +) -> bool: """ Check if an asset with a given hash exists in database. """ @@ -105,9 +134,39 @@ def asset_exists_by_hash(session: Session, asset_hash: str) -> bool: ).first() return row is not None -def get_asset_info_by_id(session: Session, asset_info_id: str) -> AssetInfo | None: + +def asset_info_exists_for_asset_id( + session: Session, + *, + asset_id: str, +) -> bool: + q = ( + select(sa.literal(True)) + .select_from(AssetInfo) + .where(AssetInfo.asset_id == asset_id) + .limit(1) + ) + return (session.execute(q)).first() is not None + + +def get_asset_by_hash( + session: Session, + *, + asset_hash: str, +) -> Asset | None: + return ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + + +def get_asset_info_by_id( + session: Session, + *, + asset_info_id: str, +) -> AssetInfo | None: return session.get(AssetInfo, asset_info_id) + def list_asset_infos_page( session: Session, owner_id: str = "", @@ -171,12 +230,14 @@ def list_asset_infos_page( select(AssetInfoTag.asset_info_id, Tag.name) .join(Tag, Tag.name == AssetInfoTag.tag_name) .where(AssetInfoTag.asset_info_id.in_(id_list)) + .order_by(AssetInfoTag.added_at) ) for aid, tag_name in rows.all(): tag_map[aid].append(tag_name) return infos, tag_map, total + def fetch_asset_info_asset_and_tags( session: Session, asset_info_id: str, @@ -208,6 +269,494 @@ def fetch_asset_info_asset_and_tags( tags.append(tag_name) return first_info, first_asset, tags + +def fetch_asset_info_and_asset( + session: Session, + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[AssetInfo, Asset] | None: + stmt = ( + select(AssetInfo, Asset) + .join(Asset, Asset.id == AssetInfo.asset_id) + .where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + .limit(1) + .options(noload(AssetInfo.tags)) + ) + row = session.execute(stmt) + pair = row.first() + if not pair: + return None + return pair[0], pair[1] + +def list_cache_states_by_asset_id( + session: Session, *, asset_id: str +) -> Sequence[AssetCacheState]: + return ( + session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) + ).scalars().all() + + +def touch_asset_info_by_id( + session: Session, + *, + asset_info_id: str, + ts: datetime | None = None, + only_if_newer: bool = True, +) -> None: + ts = ts or utcnow() + stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) + if only_if_newer: + stmt = stmt.where( + sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + ) + session.execute(stmt.values(last_access_time=ts)) + + +def create_asset_info_for_existing_asset( + session: Session, + *, + asset_hash: str, + name: str, + user_metadata: dict | None = None, + tags: Sequence[str] | None = None, + tag_origin: str = "manual", + owner_id: str = "", +) -> AssetInfo: + """Create or return an existing AssetInfo for an Asset identified by asset_hash.""" + now = utcnow() + asset = get_asset_by_hash(session, asset_hash=asset_hash) + if not asset: + raise ValueError(f"Unknown asset hash {asset_hash}") + + info = AssetInfo( + owner_id=owner_id, + name=name, + asset_id=asset.id, + preview_id=None, + created_at=now, + updated_at=now, + last_access_time=now, + ) + try: + with session.begin_nested(): + session.add(info) + session.flush() + except IntegrityError: + existing = ( + session.execute( + select(AssetInfo) + .options(noload(AssetInfo.tags)) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) + ) + ).unique().scalars().first() + if not existing: + raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.") + return existing + + # metadata["filename"] hack + new_meta = dict(user_metadata or {}) + computed_filename = None + try: + p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + if computed_filename: + new_meta["filename"] = computed_filename + if new_meta: + replace_asset_info_metadata_projection( + session, + asset_info_id=info.id, + user_metadata=new_meta, + ) + + if tags is not None: + set_asset_info_tags( + session, + asset_info_id=info.id, + tags=tags, + origin=tag_origin, + ) + return info + + +def set_asset_info_tags( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", +) -> dict: + desired = normalize_tags(tags) + + current = set( + tag_name for (tag_name,) in ( + session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + ).all() + ) + + to_add = [t for t in desired if t not in current] + to_remove = [t for t in current if t not in desired] + + if to_add: + ensure_tags_exist(session, to_add, tag_type="user") + session.add_all([ + AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow()) + for t in to_add + ]) + session.flush() + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + ) + session.flush() + + return {"added": to_add, "removed": to_remove, "total": desired} + + +def replace_asset_info_metadata_projection( + session: Session, + *, + asset_info_id: str, + user_metadata: dict | None = None, +) -> None: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info.user_metadata = user_metadata or {} + info.updated_at = utcnow() + session.flush() + + session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + session.flush() + + if not user_metadata: + return + + rows: list[AssetInfoMeta] = [] + for k, v in user_metadata.items(): + for r in project_kv(k, v): + rows.append( + AssetInfoMeta( + asset_info_id=asset_info_id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + +def ingest_fs_asset( + session: Session, + *, + asset_hash: str, + abs_path: str, + size_bytes: int, + mtime_ns: int, + mime_type: str | None = None, + info_name: str | None = None, + owner_id: str = "", + preview_id: str | None = None, + user_metadata: dict | None = None, + tags: Sequence[str] = (), + tag_origin: str = "manual", + require_existing_tags: bool = False, +) -> dict: + """ + Idempotently upsert: + - Asset by content hash (create if missing) + - AssetCacheState(file_path) pointing to asset_id + - Optionally AssetInfo + tag links and metadata projection + Returns flags and ids. + """ + locator = os.path.abspath(abs_path) + now = utcnow() + + if preview_id: + if not session.get(Asset, preview_id): + preview_id = None + + out: dict[str, Any] = { + "asset_created": False, + "asset_updated": False, + "state_created": False, + "state_updated": False, + "asset_info_id": None, + } + + # 1) Asset by hash + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + ).scalars().first() + if not asset: + vals = { + "hash": asset_hash, + "size_bytes": int(size_bytes), + "mime_type": mime_type, + "created_at": now, + } + res = session.execute( + sqlite.insert(Asset) + .values(**vals) + .on_conflict_do_nothing(index_elements=[Asset.hash]) + ) + if int(res.rowcount or 0) > 0: + out["asset_created"] = True + asset = ( + session.execute( + select(Asset).where(Asset.hash == asset_hash).limit(1) + ) + ).scalars().first() + if not asset: + raise RuntimeError("Asset row not found after upsert.") + else: + changed = False + if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: + asset.size_bytes = int(size_bytes) + changed = True + if mime_type and asset.mime_type != mime_type: + asset.mime_type = mime_type + changed = True + if changed: + out["asset_updated"] = True + + # 2) AssetCacheState upsert by file_path (unique) + vals = { + "asset_id": asset.id, + "file_path": locator, + "mtime_ns": int(mtime_ns), + } + ins = ( + sqlite.insert(AssetCacheState) + .values(**vals) + .on_conflict_do_nothing(index_elements=[AssetCacheState.file_path]) + ) + + res = session.execute(ins) + if int(res.rowcount or 0) > 0: + out["state_created"] = True + else: + upd = ( + sa.update(AssetCacheState) + .where(AssetCacheState.file_path == locator) + .where( + sa.or_( + AssetCacheState.asset_id != asset.id, + AssetCacheState.mtime_ns.is_(None), + AssetCacheState.mtime_ns != int(mtime_ns), + ) + ) + .values(asset_id=asset.id, mtime_ns=int(mtime_ns)) + ) + res2 = session.execute(upd) + if int(res2.rowcount or 0) > 0: + out["state_updated"] = True + + # 3) Optional AssetInfo + tags + metadata + if info_name: + try: + with session.begin_nested(): + info = AssetInfo( + owner_id=owner_id, + name=info_name, + asset_id=asset.id, + preview_id=preview_id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(info) + session.flush() + out["asset_info_id"] = info.id + except IntegrityError: + pass + + existing_info = ( + session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset.id, + AssetInfo.name == info_name, + (AssetInfo.owner_id == owner_id), + ) + .limit(1) + ) + ).unique().scalar_one_or_none() + if not existing_info: + raise RuntimeError("Failed to update or insert AssetInfo.") + + if preview_id and existing_info.preview_id != preview_id: + existing_info.preview_id = preview_id + + existing_info.updated_at = now + if existing_info.last_access_time < now: + existing_info.last_access_time = now + session.flush() + out["asset_info_id"] = existing_info.id + + norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()] + if norm and out["asset_info_id"] is not None: + if not require_existing_tags: + ensure_tags_exist(session, norm, tag_type="user") + + existing_tag_names = set( + name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all() + ) + missing = [t for t in norm if t not in existing_tag_names] + if missing and require_existing_tags: + raise ValueError(f"Unknown tags: {missing}") + + existing_links = set( + tag_name + for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"]) + ) + ).all() + ) + to_add = [t for t in norm if t in existing_tag_names and t not in existing_links] + if to_add: + session.add_all( + [ + AssetInfoTag( + asset_info_id=out["asset_info_id"], + tag_name=t, + origin=tag_origin, + added_at=now, + ) + for t in to_add + ] + ) + session.flush() + + # metadata["filename"] hack + if out["asset_info_id"] is not None: + primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id)) + computed_filename = compute_relative_filename(primary_path) if primary_path else None + + current_meta = existing_info.user_metadata or {} + new_meta = dict(current_meta) + if user_metadata is not None: + for k, v in user_metadata.items(): + new_meta[k] = v + if computed_filename: + new_meta["filename"] = computed_filename + + if new_meta != current_meta: + replace_asset_info_metadata_projection( + session, + asset_info_id=out["asset_info_id"], + user_metadata=new_meta, + ) + + try: + remove_missing_tag_for_asset_id(session, asset_id=asset.id) + except Exception: + logging.exception("Failed to clear 'missing' tag for asset %s", asset.id) + return out + + +def update_asset_info_full( + session: Session, + *, + asset_info_id: str, + name: str | None = None, + tags: Sequence[str] | None = None, + user_metadata: dict | None = None, + tag_origin: str = "manual", + asset_info_row: Any = None, +) -> AssetInfo: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + else: + info = asset_info_row + + touched = False + if name is not None and name != info.name: + info.name = name + touched = True + + computed_filename = None + try: + p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id)) + if p: + computed_filename = compute_relative_filename(p) + except Exception: + computed_filename = None + + if user_metadata is not None: + new_meta = dict(user_metadata) + if computed_filename: + new_meta["filename"] = computed_filename + replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + else: + if computed_filename: + current_meta = info.user_metadata or {} + if current_meta.get("filename") != computed_filename: + new_meta = dict(current_meta) + new_meta["filename"] = computed_filename + replace_asset_info_metadata_projection( + session, asset_info_id=asset_info_id, user_metadata=new_meta + ) + touched = True + + if tags is not None: + set_asset_info_tags( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=tag_origin, + ) + touched = True + + if touched and user_metadata is None: + info.updated_at = utcnow() + session.flush() + + return info + + +def delete_asset_info_by_id( + session: Session, + *, + asset_info_id: str, + owner_id: str, +) -> bool: + stmt = sa.delete(AssetInfo).where( + AssetInfo.id == asset_info_id, + visible_owner_clause(owner_id), + ) + return int((session.execute(stmt)).rowcount or 0) > 0 + + def list_tags_with_usage( session: Session, prefix: str | None = None, @@ -265,3 +814,163 @@ def list_tags_with_usage( rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows] return rows_norm, int(total or 0) + + +def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: + wanted = normalize_tags(list(names)) + if not wanted: + return + rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))] + ins = ( + sqlite.insert(Tag) + .values(rows) + .on_conflict_do_nothing(index_elements=[Tag.name]) + ) + session.execute(ins) + + +def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]: + return [ + tag_name for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + ] + + +def add_tags_to_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], + origin: str = "manual", + create_if_missing: bool = True, + asset_info_row: Any = None, +) -> dict: + if not asset_info_row: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"added": [], "already_present": [], "total_tags": total} + + if create_if_missing: + ensure_tags_exist(session, norm, tag_type="user") + + current = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + want = set(norm) + to_add = sorted(want - current) + + if to_add: + with session.begin_nested() as nested: + try: + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=utcnow(), + ) + for t in to_add + ] + ) + session.flush() + except IntegrityError: + nested.rollback() + + after = set(get_asset_tags(session, asset_info_id=asset_info_id)) + return { + "added": sorted(((after - current) & want)), + "already_present": sorted(want & current), + "total_tags": sorted(after), + } + + +def remove_tags_from_asset_info( + session: Session, + *, + asset_info_id: str, + tags: Sequence[str], +) -> dict: + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + norm = normalize_tags(tags) + if not norm: + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": [], "not_present": [], "total_tags": total} + + existing = { + tag_name + for (tag_name,) in ( + session.execute( + sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + ) + ).all() + } + + to_remove = sorted(set(t for t in norm if t in existing)) + not_present = sorted(set(t for t in norm if t not in existing)) + + if to_remove: + session.execute( + delete(AssetInfoTag) + .where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) + ) + session.flush() + + total = get_asset_tags(session, asset_info_id=asset_info_id) + return {"removed": to_remove, "not_present": not_present, "total_tags": total} + + +def remove_missing_tag_for_asset_id( + session: Session, + *, + asset_id: str, +) -> None: + session.execute( + sa.delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.tag_name == "missing", + ) + ) + + +def set_asset_info_preview( + session: Session, + *, + asset_info_id: str, + preview_asset_id: str | None = None, +) -> None: + """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" + info = session.get(AssetInfo, asset_info_id) + if not info: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + if preview_asset_id is None: + info.preview_id = None + else: + # validate preview asset exists + if not session.get(Asset, preview_asset_id): + raise ValueError(f"Preview Asset {preview_asset_id} not found") + info.preview_id = preview_asset_id + + info.updated_at = utcnow() + session.flush() diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 08b465b5a..5030b123a 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -1,5 +1,6 @@ import contextlib import os +from decimal import Decimal from aiohttp import web from datetime import datetime, timezone from pathlib import Path @@ -87,6 +88,40 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: targets.append((name, paths)) return targets +def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: + """Validates and maps tags -> (base_dir, subdirs_for_fs)""" + root = tags[0] + if root == "models": + if len(tags) < 2: + raise ValueError("at least two tags required for model asset") + try: + bases = folder_paths.folder_names_and_paths[tags[1]][0] + except KeyError: + raise ValueError(f"unknown model category '{tags[1]}'") + if not bases: + raise ValueError(f"no base path configured for category '{tags[1]}'") + base_dir = os.path.abspath(bases[0]) + raw_subdirs = tags[2:] + else: + base_dir = os.path.abspath( + folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() + ) + raw_subdirs = tags[1:] + for i in raw_subdirs: + if i in (".", ".."): + raise ValueError("invalid path component in tags") + + return base_dir, raw_subdirs if raw_subdirs else [] + +def ensure_within_base(candidate: str, base: str) -> None: + cand_abs = os.path.abspath(candidate) + base_abs = os.path.abspath(base) + try: + if os.path.commonpath([cand_abs, base_abs]) != base_abs: + raise ValueError("destination escapes base directory") + except Exception: + raise ValueError("invalid destination path") + def compute_relative_filename(file_path: str) -> str | None: """ Return the model's path relative to the last well-known folder (the model category), @@ -113,7 +148,6 @@ def compute_relative_filename(file_path: str) -> str | None: return "/".join(inside) return "/".join(parts) # input/output: keep all parts - def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]: """Given an absolute or relative file path, determine which root category the path belongs to: - 'input' if the file resides under `folder_paths.get_input_directory()` @@ -215,3 +249,64 @@ def collect_models_files() -> list[str]: if allowed: out.append(abs_path) return out + +def is_scalar(v): + if v is None: + return True + if isinstance(v, bool): + return True + if isinstance(v, (int, float, Decimal, str)): + return True + return False + +def project_kv(key: str, value): + """ + Turn a metadata key/value into typed projection rows. + Returns list[dict] with keys: + key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) + """ + rows: list[dict] = [] + + def _null_row(ordinal: int) -> dict: + return { + "key": key, "ordinal": ordinal, + "val_str": None, "val_num": None, "val_bool": None, "val_json": None + } + + if value is None: + rows.append(_null_row(0)) + return rows + + if is_scalar(value): + if isinstance(value, bool): + rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)}) + elif isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + rows.append({"key": key, "ordinal": 0, "val_num": num}) + elif isinstance(value, str): + rows.append({"key": key, "ordinal": 0, "val_str": value}) + else: + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows + + if isinstance(value, list): + if all(is_scalar(x) for x in value): + for i, x in enumerate(value): + if x is None: + rows.append(_null_row(i)) + elif isinstance(x, bool): + rows.append({"key": key, "ordinal": i, "val_bool": bool(x)}) + elif isinstance(x, (int, float, Decimal)): + num = x if isinstance(x, Decimal) else Decimal(str(x)) + rows.append({"key": key, "ordinal": i, "val_num": num}) + elif isinstance(x, str): + rows.append({"key": key, "ordinal": i, "val_str": x}) + else: + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + for i, x in enumerate(value): + rows.append({"key": key, "ordinal": i, "val_json": x}) + return rows + + rows.append({"key": key, "ordinal": 0, "val_json": value}) + return rows diff --git a/app/assets/manager.py b/app/assets/manager.py index 6425e7aa2..a68c8c8ae 100644 --- a/app/assets/manager.py +++ b/app/assets/manager.py @@ -1,13 +1,33 @@ +import os +import mimetypes +import contextlib from typing import Sequence from app.database.db import create_session -from app.assets.api import schemas_out +from app.assets.api import schemas_out, schemas_in from app.assets.database.queries import ( asset_exists_by_hash, + asset_info_exists_for_asset_id, + get_asset_by_hash, + get_asset_info_by_id, fetch_asset_info_asset_and_tags, + fetch_asset_info_and_asset, + create_asset_info_for_existing_asset, + touch_asset_info_by_id, + update_asset_info_full, + delete_asset_info_by_id, + list_cache_states_by_asset_id, list_asset_infos_page, list_tags_with_usage, + get_asset_tags, + add_tags_to_asset_info, + remove_tags_from_asset_info, + pick_best_live_path, + ingest_fs_asset, + set_asset_info_preview, ) +from app.assets.helpers import resolve_destination_from_tags, ensure_within_base +from app.assets.database.models import Asset def _safe_sort_field(requested: str | None) -> str: @@ -19,11 +39,28 @@ def _safe_sort_field(requested: str | None) -> str: return "created_at" -def asset_exists(asset_hash: str) -> bool: +def _get_size_mtime_ns(path: str) -> tuple[int, int]: + st = os.stat(path, follow_symlinks=True) + return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000)) + + +def _safe_filename(name: str | None, fallback: str) -> str: + n = os.path.basename((name or "").strip() or fallback) + if n: + return n + return fallback + + +def asset_exists(*, asset_hash: str) -> bool: + """ + Check if an asset with a given hash exists in database. + """ with create_session() as session: return asset_exists_by_hash(session, asset_hash=asset_hash) + def list_assets( + *, include_tags: Sequence[str] | None = None, exclude_tags: Sequence[str] | None = None, name_contains: str | None = None, @@ -63,7 +100,6 @@ def list_assets( size=int(asset.size_bytes) if asset else None, mime_type=asset.mime_type if asset else None, tags=tags, - preview_url=f"/api/assets/{info.id}/content", created_at=info.created_at, updated_at=info.updated_at, last_access_time=info.last_access_time, @@ -76,7 +112,12 @@ def list_assets( has_more=(offset + len(summaries)) < total, ) -def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail: + +def get_asset( + *, + asset_info_id: str, + owner_id: str = "", +) -> schemas_out.AssetDetail: with create_session() as session: res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) if not res: @@ -97,6 +138,358 @@ def get_asset(asset_info_id: str, owner_id: str = "") -> schemas_out.AssetDetail last_access_time=info.last_access_time, ) + +def resolve_asset_content_for_download( + *, + asset_info_id: str, + owner_id: str = "", +) -> tuple[str, str, str]: + with create_session() as session: + pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not pair: + raise ValueError(f"AssetInfo {asset_info_id} not found") + + info, asset = pair + states = list_cache_states_by_asset_id(session, asset_id=asset.id) + abs_path = pick_best_live_path(states) + if not abs_path: + raise FileNotFoundError + + touch_asset_info_by_id(session, asset_info_id=asset_info_id) + session.commit() + + ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + download_name = info.name or os.path.basename(abs_path) + return abs_path, ctype, download_name + + +def upload_asset_from_temp_path( + spec: schemas_in.UploadAssetSpec, + *, + temp_path: str, + client_filename: str | None = None, + owner_id: str = "", + expected_asset_hash: str | None = None, +) -> schemas_out.AssetCreated: + """ + Create new asset or update existing asset from a temporary file path. + """ + try: + # NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment + import app.assets.hashing as hashing + digest = hashing.blake3_hash(temp_path) + except Exception as e: + raise RuntimeError(f"failed to hash uploaded file: {e}") + asset_hash = "blake3:" + digest + + if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower(): + raise ValueError("HASH_MISMATCH") + + with create_session() as session: + existing = get_asset_by_hash(session, asset_hash=asset_hash) + if existing is not None: + with contextlib.suppress(Exception): + if temp_path and os.path.exists(temp_path): + os.remove(temp_path) + + display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest) + info = create_asset_info_for_existing_asset( + session, + asset_hash=asset_hash, + name=display_name, + user_metadata=spec.user_metadata or {}, + tags=spec.tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = get_asset_tags(session, asset_info_id=info.id) + session.commit() + + return schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=existing.hash, + size=int(existing.size_bytes) if existing.size_bytes is not None else None, + mime_type=existing.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + + base_dir, subdirs = resolve_destination_from_tags(spec.tags) + dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir + os.makedirs(dest_dir, exist_ok=True) + + src_for_ext = (client_filename or spec.name or "").strip() + _ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else "" + ext = _ext if 0 < len(_ext) <= 16 else "" + hashed_basename = f"{digest}{ext}" + dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) + ensure_within_base(dest_abs, base_dir) + + content_type = ( + mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] + or mimetypes.guess_type(hashed_basename, strict=False)[0] + or "application/octet-stream" + ) + + try: + os.replace(temp_path, dest_abs) + except Exception as e: + raise RuntimeError(f"failed to move uploaded file into place: {e}") + + try: + size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs) + except OSError as e: + raise RuntimeError(f"failed to stat destination file: {e}") + + with create_session() as session: + result = ingest_fs_asset( + session, + asset_hash=asset_hash, + abs_path=dest_abs, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest), + owner_id=owner_id, + preview_id=None, + user_metadata=spec.user_metadata or {}, + tags=spec.tags, + tag_origin="manual", + require_existing_tags=False, + ) + info_id = result["asset_info_id"] + if not info_id: + raise RuntimeError("failed to create asset metadata") + + pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + info, asset = pair + tag_names = get_asset_tags(session, asset_info_id=info.id) + created_result = schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=result["asset_created"], + ) + session.commit() + + return created_result + + +def update_asset( + *, + asset_info_id: str, + name: str | None = None, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> schemas_out.AssetUpdated: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + info = update_asset_info_full( + session, + asset_info_id=asset_info_id, + name=name, + tags=tags, + user_metadata=user_metadata, + tag_origin="manual", + asset_info_row=info_row, + ) + + tag_names = get_asset_tags(session, asset_info_id=asset_info_id) + result = schemas_out.AssetUpdated( + id=info.id, + name=info.name, + asset_hash=info.asset.hash if info.asset else None, + tags=tag_names, + user_metadata=info.user_metadata or {}, + updated_at=info.updated_at, + ) + session.commit() + + return result + + +def set_asset_preview( + *, + asset_info_id: str, + preview_asset_id: str | None = None, + owner_id: str = "", +) -> schemas_out.AssetDetail: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + set_asset_info_preview( + session, + asset_info_id=asset_info_id, + preview_asset_id=preview_asset_id, + ) + + res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not res: + raise RuntimeError("State changed during preview update") + info, asset, tags = res + result = schemas_out.AssetDetail( + id=info.id, + name=info.name, + asset_hash=asset.hash if asset else None, + size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None, + mime_type=asset.mime_type if asset else None, + tags=tags, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + ) + session.commit() + + return result + + +def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + asset_id = info_row.asset_id if info_row else None + deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + if not deleted: + session.commit() + return False + + if not delete_content_if_orphan or not asset_id: + session.commit() + return True + + still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id) + if still_exists: + session.commit() + return True + + states = list_cache_states_by_asset_id(session, asset_id=asset_id) + file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] + + asset_row = session.get(Asset, asset_id) + if asset_row is not None: + session.delete(asset_row) + + session.commit() + for p in file_paths: + with contextlib.suppress(Exception): + if p and os.path.isfile(p): + os.remove(p) + return True + + +def create_asset_from_hash( + *, + hash_str: str, + name: str, + tags: list[str] | None = None, + user_metadata: dict | None = None, + owner_id: str = "", +) -> schemas_out.AssetCreated | None: + canonical = hash_str.strip().lower() + with create_session() as session: + asset = get_asset_by_hash(session, asset_hash=canonical) + if not asset: + return None + + info = create_asset_info_for_existing_asset( + session, + asset_hash=canonical, + name=_safe_filename(name, fallback=canonical.split(":", 1)[1]), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + ) + tag_names = get_asset_tags(session, asset_info_id=info.id) + result = schemas_out.AssetCreated( + id=info.id, + name=info.name, + asset_hash=asset.hash, + size=int(asset.size_bytes), + mime_type=asset.mime_type, + tags=tag_names, + user_metadata=info.user_metadata or {}, + preview_id=info.preview_id, + created_at=info.created_at, + last_access_time=info.last_access_time, + created_new=False, + ) + session.commit() + + return result + + +def add_tags_to_asset( + *, + asset_info_id: str, + tags: list[str], + origin: str = "manual", + owner_id: str = "", +) -> schemas_out.TagsAdd: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + data = add_tags_to_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + origin=origin, + create_if_missing=True, + asset_info_row=info_row, + ) + session.commit() + return schemas_out.TagsAdd(**data) + + +def remove_tags_from_asset( + *, + asset_info_id: str, + tags: list[str], + owner_id: str = "", +) -> schemas_out.TagsRemove: + with create_session() as session: + info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) + if not info_row: + raise ValueError(f"AssetInfo {asset_info_id} not found") + if info_row.owner_id and info_row.owner_id != owner_id: + raise PermissionError("not owner") + + data = remove_tags_from_asset_info( + session, + asset_info_id=asset_info_id, + tags=tags, + ) + session.commit() + return schemas_out.TagsRemove(**data) + + def list_tags( prefix: str | None = None, limit: int = 100, diff --git a/app/assets/scanner.py b/app/assets/scanner.py index a16e41d94..0172a5c2f 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -27,6 +27,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No t_start = time.perf_counter() created = 0 skipped_existing = 0 + orphans_pruned = 0 paths: list[str] = [] try: existing_paths: set[str] = set() @@ -38,6 +39,11 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No except Exception as e: logging.exception("fast DB scan failed for %s: %s", r, e) + try: + orphans_pruned = _prune_orphaned_assets(roots) + except Exception as e: + logging.exception("orphan pruning failed: %s", e) + if "models" in roots: paths.extend(collect_models_files()) if "input" in roots: @@ -85,15 +91,43 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No finally: if enable_logging: logging.info( - "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, total_seen=%d)", + "Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)", roots, time.perf_counter() - t_start, created, skipped_existing, + orphans_pruned, len(paths), ) +def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: + """Prune cache states outside configured prefixes, then delete orphaned seed assets.""" + all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] + if not all_prefixes: + return 0 + + def make_prefix_condition(prefix: str): + base = prefix if prefix.endswith(os.sep) else prefix + os.sep + escaped, esc = escape_like_prefix(base) + return AssetCacheState.file_path.like(escaped + "%", escape=esc) + + matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes]) + + orphan_subq = ( + sqlalchemy.select(Asset.id) + .outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id) + .where(Asset.hash.is_(None), AssetCacheState.id.is_(None)) + ).scalar_subquery() + + with create_session() as sess: + sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix)) + sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq))) + result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq))) + sess.commit() + return result.rowcount + + def _fast_db_consistency_pass( root: RootType, *, diff --git a/tests-unit/assets_test/conftest.py b/tests-unit/assets_test/conftest.py new file mode 100644 index 000000000..0a57dd7b5 --- /dev/null +++ b/tests-unit/assets_test/conftest.py @@ -0,0 +1,271 @@ +import contextlib +import json +import os +import socket +import subprocess +import sys +import tempfile +import time +from pathlib import Path +from typing import Callable, Iterator, Optional + +import pytest +import requests + + +def pytest_addoption(parser: pytest.Parser) -> None: + """ + Allow overriding the database URL used by the spawned ComfyUI process. + Priority: + 1) --db-url command line option + 2) ASSETS_TEST_DB_URL environment variable (used by CI) + 3) default: None (will use file-backed sqlite in temp dir) + """ + parser.addoption( + "--db-url", + action="store", + default=os.environ.get("ASSETS_TEST_DB_URL"), + help="SQLAlchemy DB URL (e.g. sqlite:///path/to/db.sqlite3)", + ) + + +def _free_port() -> int: + with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s: + s.bind(("127.0.0.1", 0)) + return s.getsockname()[1] + + +def _make_base_dirs(root: Path) -> None: + for sub in ("models", "custom_nodes", "input", "output", "temp", "user"): + (root / sub).mkdir(parents=True, exist_ok=True) + + +def _wait_http_ready(base: str, session: requests.Session, timeout: float = 90.0) -> None: + start = time.time() + last_err = None + while time.time() - start < timeout: + try: + r = session.get(base + "/api/assets", timeout=5) + if r.status_code in (200, 400): + return + except Exception as e: + last_err = e + time.sleep(0.25) + raise RuntimeError(f"ComfyUI HTTP did not become ready: {last_err}") + + +@pytest.fixture(scope="session") +def comfy_tmp_base_dir() -> Path: + env_base = os.environ.get("ASSETS_TEST_BASE_DIR") + created_by_fixture = False + if env_base: + tmp = Path(env_base) + tmp.mkdir(parents=True, exist_ok=True) + else: + tmp = Path(tempfile.mkdtemp(prefix="comfyui-assets-tests-")) + created_by_fixture = True + _make_base_dirs(tmp) + yield tmp + if created_by_fixture: + with contextlib.suppress(Exception): + for p in sorted(tmp.rglob("*"), reverse=True): + if p.is_file() or p.is_symlink(): + p.unlink(missing_ok=True) + for p in sorted(tmp.glob("**/*"), reverse=True): + with contextlib.suppress(Exception): + p.rmdir() + tmp.rmdir() + + +@pytest.fixture(scope="session") +def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest): + """ + Boot ComfyUI subprocess with: + - sandbox base dir + - file-backed sqlite DB in temp dir + - autoscan disabled + Returns (base_url, process, port) + """ + port = _free_port() + db_url = request.config.getoption("--db-url") + if not db_url: + # Use a file-backed sqlite database in the temp directory + db_path = comfy_tmp_base_dir / "assets-test.sqlite3" + db_url = f"sqlite:///{db_path}" + + logs_dir = comfy_tmp_base_dir / "logs" + logs_dir.mkdir(exist_ok=True) + out_log = open(logs_dir / "stdout.log", "w", buffering=1) + err_log = open(logs_dir / "stderr.log", "w", buffering=1) + + comfy_root = Path(__file__).resolve().parent.parent.parent + if not (comfy_root / "main.py").is_file(): + raise FileNotFoundError(f"main.py not found under {comfy_root}") + + proc = subprocess.Popen( + args=[ + sys.executable, + "main.py", + f"--base-directory={str(comfy_tmp_base_dir)}", + f"--database-url={db_url}", + "--disable-assets-autoscan", + "--listen", + "127.0.0.1", + "--port", + str(port), + "--cpu", + ], + stdout=out_log, + stderr=err_log, + cwd=str(comfy_root), + env={**os.environ}, + ) + + for _ in range(50): + if proc.poll() is not None: + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI exited early with code {proc.returncode}") + time.sleep(0.1) + + base_url = f"http://127.0.0.1:{port}" + try: + with requests.Session() as s: + _wait_http_ready(base_url, s, timeout=90.0) + yield base_url, proc, port + except Exception as e: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=10) + with contextlib.suppress(Exception): + out_log.flush() + err_log.flush() + raise RuntimeError(f"ComfyUI did not become ready: {e}") + + if proc and proc.poll() is None: + with contextlib.suppress(Exception): + proc.terminate() + proc.wait(timeout=15) + out_log.close() + err_log.close() + + +@pytest.fixture +def http() -> Iterator[requests.Session]: + with requests.Session() as s: + s.timeout = 120 + yield s + + +@pytest.fixture +def api_base(comfy_url_and_proc) -> str: + base_url, _proc, _port = comfy_url_and_proc + return base_url + + +def _post_multipart_asset( + session: requests.Session, + base: str, + *, + name: str, + tags: list[str], + meta: dict, + data: bytes, + extra_fields: Optional[dict] = None, +) -> tuple[int, dict]: + files = {"file": (name, data, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps(meta), + } + if extra_fields: + for k, v in extra_fields.items(): + form_data[k] = v + r = session.post(base + "/api/assets", files=files, data=form_data, timeout=120) + return r.status_code, r.json() + + +@pytest.fixture +def make_asset_bytes() -> Callable[[str, int], bytes]: + def _make(name: str, size: int = 8192) -> bytes: + seed = sum(ord(c) for c in name) % 251 + return bytes((i * 31 + seed) % 256 for i in range(size)) + return _make + + +@pytest.fixture +def asset_factory(http: requests.Session, api_base: str): + """ + Returns create(name, tags, meta, data) -> response dict + Tracks created ids and deletes them after the test. + """ + created: list[str] = [] + + def create(name: str, tags: list[str], meta: dict, data: bytes) -> dict: + status, body = _post_multipart_asset(http, api_base, name=name, tags=tags, meta=meta, data=data) + assert status in (200, 201), body + created.append(body["id"]) + return body + + yield create + + for aid in created: + with contextlib.suppress(Exception): + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + + +@pytest.fixture +def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_base: str) -> dict: + """ + Upload one asset with ".safetensors" extension into models/checkpoints/unit-tests/. + Returns response dict with id, asset_hash, tags, etc. + """ + name = "unit_1_example.safetensors" + p = getattr(request, "param", {}) or {} + tags: Optional[list[str]] = p.get("tags") + if tags is None: + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None} + files = {"file": (name, b"A" * 4096, "application/octet-stream")} + form_data = { + "tags": json.dumps(tags), + "name": name, + "user_metadata": json.dumps(meta), + } + r = http.post(api_base + "/api/assets", files=files, data=form_data, timeout=120) + body = r.json() + assert r.status_code == 201, body + return body + + +@pytest.fixture(autouse=True) +def autoclean_unit_test_assets(http: requests.Session, api_base: str): + """Ensure isolation by removing all AssetInfo rows tagged with 'unit-tests' after each test.""" + yield + + while True: + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "limit": "500", "sort": "name"}, + timeout=30, + ) + if r.status_code != 200: + break + body = r.json() + ids = [a["id"] for a in body.get("assets", [])] + if not ids: + break + for aid in ids: + with contextlib.suppress(Exception): + http.delete(f"{api_base}/api/assets/{aid}", timeout=30) + + +def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None: + """Force a fast sync/seed pass by calling the seed endpoint.""" + session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30) + time.sleep(0.2) + + +def get_asset_filename(asset_hash: str, extension: str) -> str: + return asset_hash.removeprefix("blake3:") + extension diff --git a/tests-unit/assets_test/test_assets_missing_sync.py b/tests-unit/assets_test/test_assets_missing_sync.py new file mode 100644 index 000000000..78fa7b404 --- /dev/null +++ b/tests-unit/assets_test/test_assets_missing_sync.py @@ -0,0 +1,348 @@ +import os +import uuid +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_seed_asset_removed_when_file_is_deleted( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Asset without hash (seed) whose file disappears: + after triggering sync_seed_assets, Asset + AssetInfo disappear. + """ + # Create a file directly under input/unit-tests/ so tags include "unit-tests" + case_dir = comfy_tmp_base_dir / root / "unit-tests" / "syncseed" + case_dir.mkdir(parents=True, exist_ok=True) + name = f"seed_{uuid.uuid4().hex[:8]}.bin" + fp = case_dir / name + fp.write_bytes(b"Z" * 2048) + + # Trigger a seed sync so DB sees this path (seed asset => hash is NULL) + trigger_sync_seed_assets(http, api_base) + + # Verify it is visible via API and carries no hash (seed) + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + timeout=120, + ) + body1 = r1.json() + assert r1.status_code == 200 + # there should be exactly one with that name + matches = [a for a in body1.get("assets", []) if a.get("name") == name] + assert matches + assert matches[0].get("asset_hash") is None + asset_info_id = matches[0]["id"] + + # Remove the underlying file and sync again + if fp.exists(): + fp.unlink() + + trigger_sync_seed_assets(http, api_base) + + # It should disappear (AssetInfo and seed Asset gone) + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,syncseed", "name_contains": name}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + matches2 = [a for a in body2.get("assets", []) if a.get("name") == name] + assert not matches2, f"Seed asset {asset_info_id} should be gone after sync" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to verify and clear missing tags") +def test_hashed_asset_missing_tag_added_then_removed_after_scan( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """Hashed asset with a single cache_state: + 1. delete its file -> sync adds 'missing' + 2. restore file -> sync removes 'missing' + """ + name = "missing_tag_test.png" + tags = ["input", "unit-tests", "msync2"] + data = make_asset_bytes(name, 4096) + a = asset_factory(name, tags, {}, data) + + # Compute its on-disk path and remove it + dest = comfy_tmp_base_dir / "input" / "unit-tests" / "msync2" / get_asset_filename(a["asset_hash"], ".png") + assert dest.exists(), f"Expected asset file at {dest}" + dest.unlink() + + # Fast sync should add 'missing' to the AssetInfo + trigger_sync_seed_assets(http, api_base) + + g1 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" in set(d1.get("tags", [])), "Expected 'missing' tag after deletion" + + # Restore the file with the exact same content and sync again + dest.parent.mkdir(parents=True, exist_ok=True) + dest.write_bytes(data) + + trigger_sync_seed_assets(http, api_base) + + g2 = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Missing tag should be cleared after verify" + + +def test_hashed_asset_two_asset_infos_both_get_missing( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, +): + """Hashed asset with a single cache_state, but two AssetInfo rows: + deleting the single file then syncing should add 'missing' to both infos. + """ + # Upload one hashed asset + name = "two_infos_one_path.png" + base_tags = ["input", "unit-tests", "multiinfo"] + created = asset_factory(name, base_tags, {}, b"A" * 2048) + + # Create second AssetInfo for the same Asset via from-hash + payload = { + "hash": created["asset_hash"], + "name": "two_infos_one_path_copy.png", + "tags": base_tags, # keep it in our unit-tests scope for cleanup + "user_metadata": {"k": "v"}, + } + r2 = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120) + b2 = r2.json() + assert r2.status_code == 201, b2 + second_id = b2["id"] + + # Remove the single underlying file + p = comfy_tmp_base_dir / "input" / "unit-tests" / "multiinfo" / get_asset_filename(b2["asset_hash"], ".png") + assert p.exists() + p.unlink() + + r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags0 = r0.json() + assert r0.status_code == 200, tags0 + byname0 = {t["name"]: t for t in tags0.get("tags", [])} + old_missing = int(byname0.get("missing", {}).get("count", 0)) + + # Sync -> both AssetInfos for this asset must receive 'missing' + trigger_sync_seed_assets(http, api_base) + + ga = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + da = ga.json() + assert ga.status_code == 200, da + assert "missing" in set(da.get("tags", [])) + + gb = http.get(f"{api_base}/api/assets/{second_id}", timeout=120) + db = gb.json() + assert gb.status_code == 200, db + assert "missing" in set(db.get("tags", [])) + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfos) + r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags1 = r1.json() + assert r1.status_code == 200, tags1 + byname1 = {t["name"]: t for t in tags1.get("tags", [])} + new_missing = int(byname1.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +def test_hashed_asset_two_cache_states_partial_delete_then_full_delete( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Hashed asset with two cache_state rows: + 1. delete one file -> sync should NOT add 'missing' + 2. delete second file -> sync should add 'missing' + """ + name = "two_cache_states_partial_delete.png" + tags = ["input", "unit-tests", "dual"] + data = make_asset_bytes(name, 3072) + + created = asset_factory(name, tags, {}, data) + path1 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual" / get_asset_filename(created["asset_hash"], ".png") + assert path1.exists() + + # Create a second on-disk copy under the same root but different subfolder + path2 = comfy_tmp_base_dir / "input" / "unit-tests" / "dual_copy" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + + # Fast seed so the second path appears (as a seed initially) + trigger_sync_seed_assets(http, api_base) + + # Deduplication of AssetInfo-s will not happen as first AssetInfo has owner='default' and second has empty owner. + run_scan_and_wait("input") + + # Remove only one file and sync -> asset should still be healthy (no 'missing') + path1.unlink() + trigger_sync_seed_assets(http, api_base) + + g1 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" not in set(d1.get("tags", [])), "Should not be missing while one valid path remains" + + # Baseline 'missing' usage count just before last file removal + r0 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags0 = r0.json() + assert r0.status_code == 200, tags0 + old_missing = int({t["name"]: t for t in tags0.get("tags", [])}.get("missing", {}).get("count", 0)) + + # Remove the second (last) file and sync -> now we expect 'missing' on this AssetInfo + path2.unlink() + trigger_sync_seed_assets(http, api_base) + + g2 = http.get(f"{api_base}/api/assets/{created['id']}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" in set(d2.get("tags", [])), "Missing must be set once no valid paths remain" + + # Tag usage for 'missing' increased by exactly 2 (two AssetInfo for one Asset) + r1 = http.get(api_base + "/api/tags", params={"limit": "1000", "include_zero": "false"}, timeout=120) + tags1 = r1.json() + assert r1.status_code == 200, tags1 + new_missing = int({t["name"]: t for t in tags1.get("tags", [])}.get("missing", {}).get("count", 0)) + assert new_missing == old_missing + 2 + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_missing_tag_clears_on_fastpass_when_mtime_and_size_match( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """ + Fast pass alone clears 'missing' when size and mtime match exactly: + 1) upload (hashed), record original mtime_ns + 2) delete -> fast pass adds 'missing' + 3) restore same bytes and set mtime back to the original value + 4) run fast pass again -> 'missing' is removed (no slow scan) + """ + scope = f"fastclear-{uuid.uuid4().hex[:6]}" + name = "fastpass_clear.bin" + data = make_asset_bytes(name, 3072) + + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p = base / get_asset_filename(a["asset_hash"], ".bin") + st0 = p.stat() + orig_mtime_ns = getattr(st0, "st_mtime_ns", int(st0.st_mtime * 1_000_000_000)) + + # Delete -> fast pass adds 'missing' + p.unlink() + trigger_sync_seed_assets(http, api_base) + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" in set(d1.get("tags", [])) + + # Restore same bytes and revert mtime to the original value + p.parent.mkdir(parents=True, exist_ok=True) + p.write_bytes(data) + # set both atime and mtime in ns to ensure exact match + os.utime(p, ns=(orig_mtime_ns, orig_mtime_ns)) + + # Fast pass should clear 'missing' without a scan + trigger_sync_seed_assets(http, api_base) + g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + assert "missing" not in set(d2.get("tags", [])), "Fast pass should clear 'missing' when size+mtime match" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_fastpass_removes_stale_state_row_no_missing( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two states: + - delete one file + - run fast pass only + Expect: + - asset stays healthy (no 'missing') + - stale AssetCacheState row for the deleted path is removed. + We verify this behaviorally by recreating the deleted path and running fast pass again: + a new *seed* AssetInfo is created, which proves the old state row was not reused. + """ + scope = f"stale-{uuid.uuid4().hex[:6]}" + name = "two_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload hashed asset at path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + base = comfy_tmp_base_dir / root / "unit-tests" / scope + a1_filename = get_asset_filename(a["asset_hash"], ".bin") + p1 = base / a1_filename + assert p1.exists() + + aid = a["id"] + h = a["asset_hash"] + + # Create second state path2, seed+scan to dedupe into the same Asset + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Delete path1 and run fast pass -> no 'missing' and stale state row should be removed + p1.unlink() + trigger_sync_seed_assets(http, api_base) + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + assert "missing" not in set(d1.get("tags", [])) + + # Recreate path1 and run fast pass again. + # If the stale state row was removed, a NEW seed AssetInfo will appear for this path. + p1.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + + rl = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}"}, + timeout=120, + ) + bl = rl.json() + assert rl.status_code == 200, bl + items = bl.get("assets", []) + # one hashed AssetInfo (asset_hash == h) + one seed AssetInfo (asset_hash == null) + hashes = [it.get("asset_hash") for it in items if it.get("name") in (name, a1_filename)] + assert h in hashes + assert any(x is None for x in hashes), "Expected a new seed AssetInfo for the recreated path" + + # Asset identity still healthy + rh = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120) + assert rh.status_code == 200 diff --git a/tests-unit/assets_test/test_crud.py b/tests-unit/assets_test/test_crud.py new file mode 100644 index 000000000..d2b69f475 --- /dev/null +++ b/tests-unit/assets_test/test_crud.py @@ -0,0 +1,306 @@ +import uuid +from concurrent.futures import ThreadPoolExecutor +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def test_create_from_hash_success( + http: requests.Session, api_base: str, seeded_asset: dict +): + h = seeded_asset["asset_hash"] + payload = { + "hash": h, + "name": "from_hash_ok.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "from-hash"], + "user_metadata": {"k": "v"}, + } + r1 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + assert b1["asset_hash"] == h + assert b1["created_new"] is False + aid = b1["id"] + + # Calling again with the same name should return the same AssetInfo id + r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + b2 = r2.json() + assert r2.status_code == 201, b2 + assert b2["id"] == aid + + +def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # GET detail + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200, detail + assert detail["id"] == aid + assert "user_metadata" in detail + assert "filename" in detail["user_metadata"] + + # DELETE + rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + assert rd.status_code == 204 + + # GET again -> 404 + rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + body = rg2.json() + assert rg2.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_delete_upon_reference_count( + http: requests.Session, api_base: str, seeded_asset: dict +): + # Create a second reference to the same asset via from-hash + src_hash = seeded_asset["asset_hash"] + payload = { + "hash": src_hash, + "name": "unit_ref_copy.safetensors", + "tags": ["models", "checkpoints", "unit-tests", "del-flow"], + "user_metadata": {"note": "copy"}, + } + r2 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + copy = r2.json() + assert r2.status_code == 201, copy + assert copy["asset_hash"] == src_hash + assert copy["created_new"] is False + + # Delete original reference -> asset identity must remain + aid1 = seeded_asset["id"] + rd1 = http.delete(f"{api_base}/api/assets/{aid1}", timeout=120) + assert rd1.status_code == 204 + + rh1 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh1.status_code == 200 # identity still present + + # Delete the last reference with default semantics -> identity and cached files removed + aid2 = copy["id"] + rd2 = http.delete(f"{api_base}/api/assets/{aid2}", timeout=120) + assert rd2.status_code == 204 + + rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120) + assert rh2.status_code == 404 # orphan content removed + + +def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + original_tags = seeded_asset["tags"] + + payload = { + "name": "unit_1_renamed.safetensors", + "user_metadata": {"purpose": "updated", "epoch": 2}, + } + ru = http.put(f"{api_base}/api/assets/{aid}", json=payload, timeout=120) + body = ru.json() + assert ru.status_code == 200, body + assert body["name"] == payload["name"] + assert body["tags"] == original_tags # tags unchanged + assert body["user_metadata"]["purpose"] == "updated" + # filename should still be present and normalized by server + assert "filename" in body["user_metadata"] + + +def test_head_asset_by_hash(http: requests.Session, api_base: str, seeded_asset: dict): + h = seeded_asset["asset_hash"] + + # Existing + rh1 = http.head(f"{api_base}/api/assets/hash/{h}", timeout=120) + assert rh1.status_code == 200 + + # Non-existent + rh2 = http.head(f"{api_base}/api/assets/hash/blake3:{'0'*64}", timeout=120) + assert rh2.status_code == 404 + + +def test_head_asset_bad_hash_returns_400_and_no_body(http: requests.Session, api_base: str): + # Invalid format; handler returns a JSON error, but HEAD responses must not carry a payload. + # requests exposes an empty body for HEAD, so validate status and that there is no payload. + rh = http.head(f"{api_base}/api/assets/hash/not_a_hash", timeout=120) + assert rh.status_code == 400 + body = rh.content + assert body == b"" + + +def test_delete_nonexistent_returns_404(http: requests.Session, api_base: str): + bogus = str(uuid.uuid4()) + r = http.delete(f"{api_base}/api/assets/{bogus}", timeout=120) + body = r.json() + assert r.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_create_from_hash_invalids(http: requests.Session, api_base: str): + # Bad hash algorithm + bad = { + "hash": "sha256:" + "0" * 64, + "name": "x.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + r1 = http.post(f"{api_base}/api/assets/from-hash", json=bad, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Invalid JSON body + r2 = http.post(f"{api_base}/api/assets/from-hash", data=b"{not json}", timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_JSON" + + +def test_get_update_download_bad_ids(http: requests.Session, api_base: str): + # All endpoints should be not found, as we UUID regex directly in the route definition. + bad_id = "not-a-uuid" + + r1 = http.get(f"{api_base}/api/assets/{bad_id}", timeout=120) + assert r1.status_code == 404 + + r3 = http.get(f"{api_base}/api/assets/{bad_id}/content", timeout=120) + assert r3.status_code == 404 + + +def test_update_requires_at_least_one_field(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + r = http.put(f"{api_base}/api/assets/{aid}", json={}, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_concurrent_delete_same_asset_info_single_204( + root: str, + http: requests.Session, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Many concurrent DELETE for the same AssetInfo should result in: + - exactly one 204 No Content (the one that actually deleted) + - all others 404 Not Found (row already gone) + """ + scope = f"conc-del-{uuid.uuid4().hex[:6]}" + name = "to_delete.bin" + data = make_asset_bytes(name, 1536) + + created = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = created["id"] + + # Hit the same endpoint N times in parallel. + n_tests = 4 + url = f"{api_base}/api/assets/{aid}?delete_content=false" + + def _do_delete(delete_url): + with requests.Session() as s: + return s.delete(delete_url, timeout=120).status_code + + with ThreadPoolExecutor(max_workers=n_tests) as ex: + statuses = list(ex.map(_do_delete, [url] * n_tests)) + + # Exactly one actual delete, the rest must be 404 + assert statuses.count(204) == 1, f"Expected exactly one 204; got: {statuses}" + assert statuses.count(404) == n_tests - 1, f"Expected {n_tests-1} 404; got: {statuses}" + + # The resource must be gone. + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + assert rg.status_code == 404 + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_metadata_filename_is_set_for_seed_asset_without_hash( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """Seed ingest (no hash yet) must compute user_metadata['filename'] immediately.""" + scope = f"seedmeta-{uuid.uuid4().hex[:6]}" + name = "seed_filename.bin" + + base = comfy_tmp_base_dir / root / "unit-tests" / scope / "a" / "b" + base.mkdir(parents=True, exist_ok=True) + fp = base / name + fp.write_bytes(b"Z" * 2048) + + trigger_sync_seed_assets(http, api_base) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": name}, + timeout=120, + ) + body = r1.json() + assert r1.status_code == 200, body + matches = [a for a in body.get("assets", []) if a.get("name") == name] + assert matches, "Seed asset should be visible after sync" + assert matches[0].get("asset_hash") is None # still a seed + aid = matches[0]["id"] + + r2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = r2.json() + assert r2.status_code == 200, detail + filename = (detail.get("user_metadata") or {}).get("filename") + expected = str(fp.relative_to(comfy_tmp_base_dir / root)).replace("\\", "/") + assert filename == expected, f"expected filename={expected}, got {filename!r}" + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to retarget cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_metadata_filename_computed_and_updated_on_retarget( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + 1) Ingest under {root}/unit-tests//a/b/ -> filename reflects relative path. + 2) Retarget by copying to {root}/unit-tests//x/, remove old file, + run fast pass + scan -> filename updates to new relative path. + """ + scope = f"meta-fn-{uuid.uuid4().hex[:6]}" + name1 = "compute_metadata_filename.png" + name2 = "compute_changed_metadata_filename.png" + data = make_asset_bytes(name1, 2100) + + # Upload into nested path a/b + a = asset_factory(name1, [root, "unit-tests", scope, "a", "b"], {}, data) + aid = a["id"] + + root_base = comfy_tmp_base_dir / root + p1 = (root_base / "unit-tests" / scope / "a" / "b" / get_asset_filename(a["asset_hash"], ".png")) + assert p1.exists() + + # filename at ingest should be the path relative to root + rel1 = str(p1.relative_to(root_base)).replace("\\", "/") + g1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = g1.json() + assert g1.status_code == 200, d1 + fn1 = d1["user_metadata"].get("filename") + assert fn1 == rel1 + + # Retarget: copy to x/, remove old, then sync+scan + p2 = root_base / "unit-tests" / scope / "x" / name2 + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + if p1.exists(): + p1.unlink() + + trigger_sync_seed_assets(http, api_base) # seed the new path + run_scan_and_wait(root) # verify/hash and reconcile + + # filename should now point at x/ + rel2 = str(p2.relative_to(root_base)).replace("\\", "/") + g2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d2 = g2.json() + assert g2.status_code == 200, d2 + fn2 = d2["user_metadata"].get("filename") + assert fn2 == rel2 diff --git a/tests-unit/assets_test/test_downloads.py b/tests-unit/assets_test/test_downloads.py new file mode 100644 index 000000000..cdebf9082 --- /dev/null +++ b/tests-unit/assets_test/test_downloads.py @@ -0,0 +1,166 @@ +import time +import uuid +from datetime import datetime +from pathlib import Path +from typing import Optional + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +def test_download_attachment_and_inline(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # default attachment + r1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + data = r1.content + assert r1.status_code == 200 + cd = r1.headers.get("Content-Disposition", "") + assert "attachment" in cd + assert data and len(data) == 4096 + + # inline requested + r2 = http.get(f"{api_base}/api/assets/{aid}/content?disposition=inline", timeout=120) + r2.content + assert r2.status_code == 200 + cd2 = r2.headers.get("Content-Disposition", "") + assert "inline" in cd2 + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_download_chooses_existing_state_and_updates_access_time( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """ + Hashed asset with two state paths: if the first one disappears, + GET /content still serves from the remaining path and bumps last_access_time. + """ + scope = f"dl-first-{uuid.uuid4().hex[:6]}" + name = "first_existing_state.bin" + data = make_asset_bytes(name, 3072) + + # Upload -> path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + path1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert path1.exists() + + # Seed path2 by copying, then scan to dedupe into a second state + path2 = base / "alt" / name + path2.parent.mkdir(parents=True, exist_ok=True) + path2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Remove path1 so server must fall back to path2 + path1.unlink() + + # last_access_time before + rg0 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d0 = rg0.json() + assert rg0.status_code == 200, d0 + ts0 = d0.get("last_access_time") + + time.sleep(0.05) + r = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + blob = r.content + assert r.status_code == 200 + assert blob == data # must serve from the surviving state (same bytes) + + rg1 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + d1 = rg1.json() + assert rg1.status_code == 200, d1 + ts1 = d1.get("last_access_time") + + def _parse_iso8601(s: Optional[str]) -> Optional[float]: + if not s: + return None + s = s[:-1] if s.endswith("Z") else s + return datetime.fromisoformat(s).timestamp() + + t0 = _parse_iso8601(ts0) + t1 = _parse_iso8601(ts1) + assert t1 is not None + if t0 is not None: + assert t1 > t0 + + +@pytest.mark.parametrize("seeded_asset", [{"tags": ["models", "checkpoints"]}], indirect=True) +def test_download_missing_file_returns_404( + http: requests.Session, api_base: str, comfy_tmp_base_dir: Path, seeded_asset: dict +): + # Remove the underlying file then attempt download. + # We initialize fixture without additional tags to know exactly the asset file path. + try: + aid = seeded_asset["id"] + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200 + asset_filename = get_asset_filename(detail["asset_hash"], ".safetensors") + abs_path = comfy_tmp_base_dir / "models" / "checkpoints" / asset_filename + assert abs_path.exists() + abs_path.unlink() + + r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + assert r2.status_code == 404 + body = r2.json() + assert body["error"]["code"] == "FILE_NOT_FOUND" + finally: + # We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually. + dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120) + dr.content + + +@pytest.mark.skip(reason="Requires computing hashes of files in directories to deduplicate into multiple cache states") +@pytest.mark.parametrize("root", ["input", "output"]) +def test_download_404_if_all_states_missing( + root: str, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, + run_scan_and_wait, +): + """Multi-state asset: after the last remaining on-disk file is removed, download must return 404.""" + scope = f"dl-404-{uuid.uuid4().hex[:6]}" + name = "missing_all_states.bin" + data = make_asset_bytes(name, 2048) + + # Upload -> path1 + a = asset_factory(name, [root, "unit-tests", scope], {}, data) + aid = a["id"] + + base = comfy_tmp_base_dir / root / "unit-tests" / scope + p1 = base / get_asset_filename(a["asset_hash"], ".bin") + assert p1.exists() + + # Seed a second state and dedupe + p2 = base / "copy" / name + p2.parent.mkdir(parents=True, exist_ok=True) + p2.write_bytes(data) + trigger_sync_seed_assets(http, api_base) + run_scan_and_wait(root) + + # Remove first file -> download should still work via the second state + p1.unlink() + ok1 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + b1 = ok1.content + assert ok1.status_code == 200 and b1 == data + + # Remove the last file -> download must 404 + p2.unlink() + r2 = http.get(f"{api_base}/api/assets/{aid}/content", timeout=120) + body = r2.json() + assert r2.status_code == 404 + assert body["error"]["code"] == "FILE_NOT_FOUND" diff --git a/tests-unit/assets_test/test_list_filter.py b/tests-unit/assets_test/test_list_filter.py new file mode 100644 index 000000000..82e109832 --- /dev/null +++ b/tests-unit/assets_test/test_list_filter.py @@ -0,0 +1,342 @@ +import time +import uuid + +import requests + + +def test_list_assets_paging_and_sort(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + names = ["a1_u.safetensors", "a2_u.safetensors", "a3_u.safetensors"] + for n in names: + asset_factory( + n, + ["models", "checkpoints", "unit-tests", "paging"], + {"epoch": 1}, + make_asset_bytes(n, size=2048), + ) + + # name ascending for stable order + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "0"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == sorted(names)[:2] + assert b1["has_more"] is True + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,paging", "sort": "name", "order": "asc", "limit": "2", "offset": "2"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == sorted(names)[2:] + assert b2["has_more"] is False + + +def test_list_assets_include_exclude_and_name_contains(http: requests.Session, api_base: str, asset_factory): + a = asset_factory("inc_a.safetensors", ["models", "checkpoints", "unit-tests", "alpha"], {}, b"X" * 1024) + b = asset_factory("inc_b.safetensors", ["models", "checkpoints", "unit-tests", "beta"], {}, b"Y" * 1024) + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,alpha", "exclude_tags": "beta", "limit": "50"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests", "name_contains": "inc_"}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in body2["assets"]] + assert a["name"] in names2 + assert b["name"] in names2 + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "non-existing-tag"}, + timeout=120, + ) + body3 = r2.json() + assert r2.status_code == 200 + assert not body3["assets"] + + +def test_list_assets_sort_by_size_both_orders(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-size"] + n1, n2, n3 = "sz1.safetensors", "sz2.safetensors", "sz3.safetensors" + asset_factory(n1, t, {}, make_asset_bytes(n1, 1024)) + asset_factory(n2, t, {}, make_asset_bytes(n2, 2048)) + asset_factory(n3, t, {}, make_asset_bytes(n3, 3072)) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "asc"}, + timeout=120, + ) + b1 = r1.json() + names = [a["name"] for a in b1["assets"]] + assert names[:3] == [n1, n2, n3] + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-size", "sort": "size", "order": "desc"}, + timeout=120, + ) + b2 = r2.json() + names2 = [a["name"] for a in b2["assets"]] + assert names2[:3] == [n3, n2, n1] + + + +def test_list_assets_sort_by_updated_at_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-upd"] + a1 = asset_factory("upd_a.safetensors", t, {}, make_asset_bytes("upd_a", 1200)) + a2 = asset_factory("upd_b.safetensors", t, {}, make_asset_bytes("upd_b", 1200)) + + # Rename the second asset to bump updated_at + rp = http.put(f"{api_base}/api/assets/{a2['id']}", json={"name": "upd_b_renamed.safetensors"}, timeout=120) + upd = rp.json() + assert rp.status_code == 200, upd + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-upd", "sort": "updated_at", "order": "desc"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == "upd_b_renamed.safetensors" + assert a1["name"] in names + + + +def test_list_assets_sort_by_last_access_time_desc(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-access"] + asset_factory("acc_a.safetensors", t, {}, make_asset_bytes("acc_a", 1100)) + time.sleep(0.02) + a2 = asset_factory("acc_b.safetensors", t, {}, make_asset_bytes("acc_b", 1100)) + + # Touch last_access_time of b by downloading its content + time.sleep(0.02) + dl = http.get(f"{api_base}/api/assets/{a2['id']}/content", timeout=120) + assert dl.status_code == 200 + dl.content + + r = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-access", "sort": "last_access_time", "order": "desc"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert names[0] == a2["name"] + + +def test_list_assets_include_tags_variants_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-include"] + a = asset_factory("incvar_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("iva")) + asset_factory("incvar_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("ivb")) + + # CSV + case-insensitive + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "UNIT-TESTS,LF-INCLUDE,alpha"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + assert not any("beta" in x for x in names1) + + # Repeated query params for include_tags + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-include"), + ("include_tags", "alpha"), + ] + r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a["name"] in names2 + assert not any("beta" in x for x in names2) + + # Duplicates and spaces in CSV + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": " unit-tests , lf-include , alpha , alpha "}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a["name"] in names3 + + +def test_list_assets_exclude_tags_dedup_and_case(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-exclude"] + a = asset_factory("ex_a_alpha.safetensors", [*t, "alpha"], {}, make_asset_bytes("exa", 900)) + asset_factory("ex_b_beta.safetensors", [*t, "beta"], {}, make_asset_bytes("exb", 900)) + + # Exclude uppercase should work + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-exclude", "exclude_tags": "BETA"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a["name"] in names1 + # Repeated excludes with duplicates + params_multi = [ + ("include_tags", "unit-tests"), + ("include_tags", "lf-exclude"), + ("exclude_tags", "beta"), + ("exclude_tags", "beta"), + ] + r2 = http.get(api_base + "/api/assets", params=params_multi, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert all("beta" not in x for x in names2) + + +def test_list_assets_name_contains_case_and_specials(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-name"] + a1 = asset_factory("CaseMix.SAFE", t, {}, make_asset_bytes("cm", 800)) + a2 = asset_factory("case-other.safetensors", t, {}, make_asset_bytes("co", 800)) + + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "casemix"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [x["name"] for x in b1["assets"]] + assert a1["name"] in names1 + + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": ".SAFE"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + names2 = [x["name"] for x in b2["assets"]] + assert a1["name"] in names2 + + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-name", "name_contains": "case-"}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [x["name"] for x in b3["assets"]] + assert a2["name"] in names3 + + +def test_list_assets_offset_beyond_total_and_limit_boundary(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "lf-pagelimits"] + asset_factory("pl1.safetensors", t, {}, make_asset_bytes("pl1", 600)) + asset_factory("pl2.safetensors", t, {}, make_asset_bytes("pl2", 600)) + asset_factory("pl3.safetensors", t, {}, make_asset_bytes("pl3", 600)) + + # Offset far beyond total + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "2", "offset": "10"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + assert not b1["assets"] + assert b1["has_more"] is False + + # Boundary large limit (<=500 is valid) + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,lf-pagelimits", "limit": "500"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + assert len(b2["assets"]) == 3 + assert b2["has_more"] is False + + +def test_list_assets_offset_negative_and_limit_nonint_rejected(http, api_base): + r1 = http.get(api_base + "/api/assets", params={"offset": "-1"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + r2 = http.get(api_base + "/api/assets", params={"limit": "abc"}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +def test_list_assets_invalid_query_rejected(http: requests.Session, api_base: str): + # limit too small + r1 = http.get(api_base + "/api/assets", params={"limit": "0"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_QUERY" + + # bad metadata JSON + r2 = http.get(api_base + "/api/assets", params={"metadata_filter": "{not json"}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_QUERY" + + +def test_list_assets_name_contains_literal_underscore( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'name_contains' must treat '_' literally, not as a SQL wildcard. + We create: + - foo_bar.safetensors (should match) + - fooxbar.safetensors (must NOT match if '_' is escaped) + - foobar.safetensors (must NOT match) + """ + scope = f"lf-underscore-{uuid.uuid4().hex[:6]}" + tags = ["models", "checkpoints", "unit-tests", scope] + + a = asset_factory("foo_bar.safetensors", tags, {}, make_asset_bytes("a", 700)) + b = asset_factory("fooxbar.safetensors", tags, {}, make_asset_bytes("b", 700)) + c = asset_factory("foobar.safetensors", tags, {}, make_asset_bytes("c", 700)) + + r = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "name_contains": "foo_bar"}, + timeout=120, + ) + body = r.json() + assert r.status_code == 200, body + names = [x["name"] for x in body["assets"]] + assert a["name"] in names, f"Expected literal underscore match to include {a['name']}" + assert b["name"] not in names, "Underscore must be escaped — should not match 'fooxbar'" + assert c["name"] not in names, "Underscore must be escaped — should not match 'foobar'" + assert body["total"] == 1 diff --git a/tests-unit/assets_test/test_metadata_filters.py b/tests-unit/assets_test/test_metadata_filters.py new file mode 100644 index 000000000..20285a3b3 --- /dev/null +++ b/tests-unit/assets_test/test_metadata_filters.py @@ -0,0 +1,395 @@ +import json + + +def test_meta_and_across_keys_and_types( + http, api_base: str, asset_factory, make_asset_bytes +): + name = "mf_and_mix.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-and"] + meta = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + asset_factory(name, tags, meta, make_asset_bytes(name, 4096)) + + # All keys must match (AND semantics) + f_ok = {"purpose": "mix", "epoch": 1, "active": True, "score": 1.23} + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_ok), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names = [a["name"] for a in b1["assets"]] + assert name in names + + # One key mismatched -> no result + f_bad = {"purpose": "mix", "epoch": 2, "active": True} + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-and", + "metadata_filter": json.dumps(f_bad), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + assert not b2["assets"] + + +def test_meta_type_strictness_int_vs_str_and_bool(http, api_base, asset_factory, make_asset_bytes): + name = "mf_types.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-types"] + meta = {"epoch": 1, "active": True} + asset_factory(name, tags, meta, make_asset_bytes(name)) + + # int filter matches numeric + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": 1}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # string "1" must NOT match numeric 1 + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"epoch": "1"}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + # bool True matches, string "true" must NOT match + r3 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": True}), + }, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"]) + + r4 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-types", + "metadata_filter": json.dumps({"active": "true"}), + }, + timeout=120, + ) + b4 = r4.json() + assert r4.status_code == 200 and not b4["assets"] + + +def test_meta_any_of_list_of_scalars(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_scalars.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-list"] + meta = {"flags": ["red", "green"]} + asset_factory(name, tags, meta, make_asset_bytes(name, 3000)) + + # Any-of should match because "green" is present + filt_ok = {"flags": ["blue", "green"]} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_ok)}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # None of provided flags present -> no match + filt_miss = {"flags": ["blue", "yellow"]} + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_miss)}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + # Duplicates in list should not break matching + filt_dup = {"flags": ["green", "green", "green"]} + r3 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-list", "metadata_filter": json.dumps(filt_dup)}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 200 and any(a["name"] == name for a in b3["assets"]) + + +def test_meta_none_semantics_missing_or_null_and_any_of_with_none( + http, api_base, asset_factory, make_asset_bytes +): + # a1: key missing; a2: explicit null; a3: concrete value + t = ["models", "checkpoints", "unit-tests", "mf-none"] + a1 = asset_factory("mf_none_missing.safetensors", t, {"x": 1}, make_asset_bytes("a1")) + a2 = asset_factory("mf_none_null.safetensors", t, {"maybe": None}, make_asset_bytes("a2")) + a3 = asset_factory("mf_none_value.safetensors", t, {"maybe": "x"}, make_asset_bytes("a3")) + + # Filter {maybe: None} must match a1 and a2, not a3 + filt = {"maybe": None} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt), "sort": "name"}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + got = [a["name"] for a in b1["assets"]] + assert a1["name"] in got and a2["name"] in got and a3["name"] not in got + + # Any-of with None should include missing/null plus value matches + filt_any = {"maybe": [None, "x"]} + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-none", "metadata_filter": json.dumps(filt_any), "sort": "name"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert a1["name"] in got2 and a2["name"] in got2 and a3["name"] in got2 + + +def test_meta_nested_json_object_equality(http, api_base, asset_factory, make_asset_bytes): + name = "mf_nested_json.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-nested"] + cfg = {"optimizer": "adam", "lr": 0.001, "schedule": {"type": "cosine", "warmup": 100}} + asset_factory(name, tags, {"config": cfg}, make_asset_bytes(name, 2200)) + + # Exact JSON object equality (same structure) + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": cfg}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Different JSON object should not match + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-nested", + "metadata_filter": json.dumps({"config": {"optimizer": "sgd"}}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_list_of_objects_any_of(http, api_base, asset_factory, make_asset_bytes): + name = "mf_list_objects.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-objlist"] + transforms = [{"type": "crop", "size": 128}, {"type": "flip", "p": 0.5}] + asset_factory(name, tags, {"transforms": transforms}, make_asset_bytes(name, 2048)) + + # Any-of for list of objects should match when one element equals the filter object + r1 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "flip", "p": 0.5}}), + }, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Non-matching object -> no match + r2 = http.get( + api_base + "/api/assets", + params={ + "include_tags": "unit-tests,mf-objlist", + "metadata_filter": json.dumps({"transforms": {"type": "rotate", "deg": 90}}), + }, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_with_special_and_unicode_keys(http, api_base, asset_factory, make_asset_bytes): + name = "mf_keys_unicode.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-keys"] + meta = { + "weird.key": "v1", + "path/like": 7, + "with:colon": True, + "ключ": "значение", + "emoji": "🐍", + } + asset_factory(name, tags, meta, make_asset_bytes(name, 1500)) + + # Match all the special keys + filt = {"weird.key": "v1", "path/like": 7, "with:colon": True, "emoji": "🐍"} + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps(filt)}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Unicode key match + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-keys", "metadata_filter": json.dumps({"ключ": "значение"})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and any(a["name"] == name for a in b2["assets"]) + + +def test_meta_with_zero_and_boolean_lists(http, api_base, asset_factory, make_asset_bytes): + t = ["models", "checkpoints", "unit-tests", "mf-zero-bool"] + a0 = asset_factory("mf_zero_count.safetensors", t, {"count": 0}, make_asset_bytes("z", 1025)) + a1 = asset_factory("mf_bool_list.safetensors", t, {"choices": [True, False]}, make_asset_bytes("b", 1026)) + + # count == 0 must match only a0 + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"count": 0})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names1 = [a["name"] for a in b1["assets"]] + assert a0["name"] in names1 and a1["name"] not in names1 + + # Any-of list of booleans: True matches second asset + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-zero-bool", "metadata_filter": json.dumps({"choices": True})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and any(a["name"] == a1["name"] for a in b2["assets"]) + + +def test_meta_mixed_list_types_and_strictness(http, api_base, asset_factory, make_asset_bytes): + name = "mf_mixed_list.safetensors" + tags = ["models", "checkpoints", "unit-tests", "mf-mixed"] + meta = {"mix": ["1", 1, True, None]} + asset_factory(name, tags, meta, make_asset_bytes(name, 1999)) + + # Should match because 1 is present + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": [2, 1]})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 and any(a["name"] == name for a in b1["assets"]) + + # Should NOT match for False + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-mixed", "metadata_filter": json.dumps({"mix": False})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_unknown_key_and_none_behavior_with_scope_tags(http, api_base, asset_factory, make_asset_bytes): + # Use a unique scope tag to avoid interference + t = ["models", "checkpoints", "unit-tests", "mf-unknown-scope"] + x = asset_factory("mf_unknown_a.safetensors", t, {"k1": 1}, make_asset_bytes("ua")) + y = asset_factory("mf_unknown_b.safetensors", t, {"k2": 2}, make_asset_bytes("ub")) + + # Filtering by unknown key with None should return both (missing key OR null) + r1 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": None})}, + timeout=120, + ) + b1 = r1.json() + assert r1.status_code == 200 + names = {a["name"] for a in b1["assets"]} + assert x["name"] in names and y["name"] in names + + # Filtering by unknown key with concrete value should return none + r2 = http.get( + api_base + "/api/assets", + params={"include_tags": "unit-tests,mf-unknown-scope", "metadata_filter": json.dumps({"unknown": "x"})}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200 and not b2["assets"] + + +def test_meta_with_tags_include_exclude_and_name_contains(http, api_base, asset_factory, make_asset_bytes): + # alpha matches epoch=1; beta has epoch=2 + a = asset_factory( + "mf_tag_alpha.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "alpha"], + {"epoch": 1}, + make_asset_bytes("alpha"), + ) + b = asset_factory( + "mf_tag_beta.safetensors", + ["models", "checkpoints", "unit-tests", "mf-tag", "beta"], + {"epoch": 2}, + make_asset_bytes("beta"), + ) + + params = { + "include_tags": "unit-tests,mf-tag,alpha", + "exclude_tags": "beta", + "name_contains": "mf_tag_", + "metadata_filter": json.dumps({"epoch": 1}), + } + r = http.get(api_base + "/api/assets", params=params, timeout=120) + body = r.json() + assert r.status_code == 200 + names = [x["name"] for x in body["assets"]] + assert a["name"] in names + assert b["name"] not in names + + +def test_meta_sort_and_paging_under_filter(http, api_base, asset_factory, make_asset_bytes): + # Three assets in same scope with different sizes and a common filter key + t = ["models", "checkpoints", "unit-tests", "mf-sort"] + n1, n2, n3 = "mf_sort_1.safetensors", "mf_sort_2.safetensors", "mf_sort_3.safetensors" + asset_factory(n1, t, {"group": "g"}, make_asset_bytes(n1, 1024)) + asset_factory(n2, t, {"group": "g"}, make_asset_bytes(n2, 2048)) + asset_factory(n3, t, {"group": "g"}, make_asset_bytes(n3, 3072)) + + # Sort by size ascending with paging + q = { + "include_tags": "unit-tests,mf-sort", + "metadata_filter": json.dumps({"group": "g"}), + "sort": "size", "order": "asc", "limit": "2", + } + r1 = http.get(api_base + "/api/assets", params=q, timeout=120) + b1 = r1.json() + assert r1.status_code == 200 + got1 = [a["name"] for a in b1["assets"]] + assert got1 == [n1, n2] + assert b1["has_more"] is True + + q2 = {**q, "offset": "2"} + r2 = http.get(api_base + "/api/assets", params=q2, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + got2 = [a["name"] for a in b2["assets"]] + assert got2 == [n3] + assert b2["has_more"] is False diff --git a/tests-unit/assets_test/test_prune_orphaned_assets.py b/tests-unit/assets_test/test_prune_orphaned_assets.py new file mode 100644 index 000000000..f602e5a77 --- /dev/null +++ b/tests-unit/assets_test/test_prune_orphaned_assets.py @@ -0,0 +1,141 @@ +import uuid +from pathlib import Path + +import pytest +import requests +from conftest import get_asset_filename, trigger_sync_seed_assets + + +@pytest.fixture +def create_seed_file(comfy_tmp_base_dir: Path): + """Create a file on disk that will become a seed asset after sync.""" + created: list[Path] = [] + + def _create(root: str, scope: str, name: str | None = None, data: bytes = b"TEST") -> Path: + name = name or f"seed_{uuid.uuid4().hex[:8]}.bin" + path = comfy_tmp_base_dir / root / "unit-tests" / scope / name + path.parent.mkdir(parents=True, exist_ok=True) + path.write_bytes(data) + created.append(path) + return path + + yield _create + + for p in created: + p.unlink(missing_ok=True) + + +@pytest.fixture +def find_asset(http: requests.Session, api_base: str): + """Query API for assets matching scope and optional name.""" + def _find(scope: str, name: str | None = None) -> list[dict]: + params = {"include_tags": f"unit-tests,{scope}"} + if name: + params["name_contains"] = name + r = http.get(f"{api_base}/api/assets", params=params, timeout=120) + assert r.status_code == 200 + assets = r.json().get("assets", []) + if name: + return [a for a in assets if a.get("name") == name] + return assets + + return _find + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_orphaned_seed_asset_is_pruned( + root: str, + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Seed asset with deleted file is removed; with file present, it survives.""" + scope = f"prune-{uuid.uuid4().hex[:6]}" + fp = create_seed_file(root, scope) + name = fp.name + + trigger_sync_seed_assets(http, api_base) + assert find_asset(scope, name), "Seed asset should exist" + + fp.unlink() + trigger_sync_seed_assets(http, api_base) + assert not find_asset(scope, name), "Orphaned seed should be pruned" + + +def test_seed_asset_with_file_survives_prune( + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Seed asset with file still on disk is NOT pruned.""" + scope = f"keep-{uuid.uuid4().hex[:6]}" + fp = create_seed_file("input", scope) + + trigger_sync_seed_assets(http, api_base) + trigger_sync_seed_assets(http, api_base) + + assert find_asset(scope, fp.name), "Seed with valid file should survive" + + +def test_hashed_asset_not_pruned_when_file_missing( + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, + asset_factory, + make_asset_bytes, +): + """Hashed assets are never deleted by prune, even without file.""" + scope = f"hashed-{uuid.uuid4().hex[:6]}" + data = make_asset_bytes("test", 2048) + a = asset_factory("test.bin", ["input", "unit-tests", scope], {}, data) + + path = comfy_tmp_base_dir / "input" / "unit-tests" / scope / get_asset_filename(a["asset_hash"], ".bin") + path.unlink() + + trigger_sync_seed_assets(http, api_base) + + r = http.get(f"{api_base}/api/assets/{a['id']}", timeout=120) + assert r.status_code == 200, "Hashed asset should NOT be pruned" + + +def test_prune_across_multiple_roots( + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, +): + """Prune correctly handles assets across input and output roots.""" + scope = f"multi-{uuid.uuid4().hex[:6]}" + input_fp = create_seed_file("input", scope, "input.bin") + create_seed_file("output", scope, "output.bin") + + trigger_sync_seed_assets(http, api_base) + assert len(find_asset(scope)) == 2 + + input_fp.unlink() + trigger_sync_seed_assets(http, api_base) + + remaining = find_asset(scope) + assert len(remaining) == 1 + assert remaining[0]["name"] == "output.bin" + + +@pytest.mark.parametrize("dirname", ["100%_done", "my_folder_name", "has spaces"]) +def test_special_chars_in_path_escaped_correctly( + dirname: str, + create_seed_file, + find_asset, + http: requests.Session, + api_base: str, + comfy_tmp_base_dir: Path, +): + """SQL LIKE wildcards (%, _) and spaces in paths don't cause false matches.""" + scope = f"special-{uuid.uuid4().hex[:6]}/{dirname}" + fp = create_seed_file("input", scope) + + trigger_sync_seed_assets(http, api_base) + trigger_sync_seed_assets(http, api_base) + + assert find_asset(scope.split("/")[0], fp.name), "Asset with special chars should survive" diff --git a/tests-unit/assets_test/test_tags.py b/tests-unit/assets_test/test_tags.py new file mode 100644 index 000000000..6b1047802 --- /dev/null +++ b/tests-unit/assets_test/test_tags.py @@ -0,0 +1,225 @@ +import json +import uuid + +import requests + + +def test_tags_present(http: requests.Session, api_base: str, seeded_asset: dict): + # Include zero-usage tags by default + r1 = http.get(api_base + "/api/tags", params={"limit": "50"}, timeout=120) + body1 = r1.json() + assert r1.status_code == 200 + names = [t["name"] for t in body1["tags"]] + # A few system tags from migration should exist: + assert "models" in names + assert "checkpoints" in names + + # Only used tags before we add anything new from this test cycle + r2 = http.get(api_base + "/api/tags", params={"include_zero": "false"}, timeout=120) + body2 = r2.json() + assert r2.status_code == 200 + # We already seeded one asset via fixture, so used tags must be non-empty + used_names = [t["name"] for t in body2["tags"]] + assert "models" in used_names + assert "checkpoints" in used_names + + # Prefix filter should refine the list + r3 = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": "uni"}, timeout=120) + b3 = r3.json() + assert r3.status_code == 200 + names3 = [t["name"] for t in b3["tags"]] + assert "unit-tests" in names3 + assert "models" not in names3 # filtered out by prefix + + # Order by name ascending should be stable + r4 = http.get(api_base + "/api/tags", params={"include_zero": "false", "order": "name_asc"}, timeout=120) + b4 = r4.json() + assert r4.status_code == 200 + names4 = [t["name"] for t in b4["tags"]] + assert names4 == sorted(names4) + + +def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory, make_asset_bytes): + # Baseline: system tags exist when include_zero (default) is true + r1 = http.get(api_base + "/api/tags", params={"limit": "500"}, timeout=120) + body1 = r1.json() + assert r1.status_code == 200 + names = [t["name"] for t in body1["tags"]] + assert "models" in names and "checkpoints" in names + + # Create a short-lived asset under input with a unique custom tag + scope = f"tags-empty-usage-{uuid.uuid4().hex[:6]}" + custom_tag = f"temp-{uuid.uuid4().hex[:8]}" + name = "tag_seed.bin" + _asset = asset_factory( + name, + ["input", "unit-tests", scope, custom_tag], + {}, + make_asset_bytes(name, 512), + ) + + # While the asset exists, the custom tag must appear when include_zero=false + r2 = http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + timeout=120, + ) + body2 = r2.json() + assert r2.status_code == 200 + used_names = [t["name"] for t in body2["tags"]] + assert custom_tag in used_names + + # Delete the asset so the tag usage drops to zero + rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120) + assert rd.status_code == 204 + + # Now the custom tag must not be returned when include_zero=false + r3 = http.get( + api_base + "/api/tags", + params={"include_zero": "false", "prefix": custom_tag, "limit": "50"}, + timeout=120, + ) + body3 = r3.json() + assert r3.status_code == 200 + names_after = [t["name"] for t in body3["tags"]] + assert custom_tag not in names_after + assert not names_after # filtered view should be empty now + + +def test_add_and_remove_tags(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add tags with duplicates and mixed case + payload_add = {"tags": ["NewTag", "unit-tests", "newtag", "BETA"]} + r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json=payload_add, timeout=120) + b1 = r1.json() + assert r1.status_code == 200, b1 + # normalized, deduplicated; 'unit-tests' was already present from the seed + assert set(b1["added"]) == {"newtag", "beta"} + assert set(b1["already_present"]) == {"unit-tests"} + assert "newtag" in b1["total_tags"] and "beta" in b1["total_tags"] + + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + g = rg.json() + assert rg.status_code == 200 + tags_now = set(g["tags"]) + assert {"newtag", "beta"}.issubset(tags_now) + + # Remove a tag and a non-existent tag + payload_del = {"tags": ["newtag", "does-not-exist"]} + r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json=payload_del, timeout=120) + b2 = r2.json() + assert r2.status_code == 200 + assert set(b2["removed"]) == {"newtag"} + assert set(b2["not_present"]) == {"does-not-exist"} + + # Verify remaining tags after deletion + rg2 = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + g2 = rg2.json() + assert rg2.status_code == 200 + tags_later = set(g2["tags"]) + assert "newtag" not in tags_later + assert "beta" in tags_later # still present + + +def test_tags_list_order_and_prefix(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + h = seeded_asset["asset_hash"] + + # Add both tags to the seeded asset (usage: orderaaa=1, orderbbb=1) + r_add = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": ["orderaaa", "orderbbb"]}, timeout=120) + add_body = r_add.json() + assert r_add.status_code == 200, add_body + + # Create another AssetInfo from the same content but tagged ONLY with 'orderbbb'. + payload = { + "hash": h, + "name": "order_only_bbb.safetensors", + "tags": ["input", "unit-tests", "orderbbb"], + "user_metadata": {}, + } + r_copy = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120) + copy_body = r_copy.json() + assert r_copy.status_code == 201, copy_body + + # 1) Default order (count_desc): 'orderbbb' should come before 'orderaaa' + # because it has higher usage (2 vs 1). + r1 = http.get(api_base + "/api/tags", params={"prefix": "order", "include_zero": "false"}, timeout=120) + b1 = r1.json() + assert r1.status_code == 200, b1 + names1 = [t["name"] for t in b1["tags"]] + counts1 = {t["name"]: t["count"] for t in b1["tags"]} + # Both must be present within the prefix subset + assert "orderaaa" in names1 and "orderbbb" in names1 + # Usage of 'orderbbb' must be >= 'orderaaa'; in our setup it's 2 vs 1 + assert counts1["orderbbb"] >= counts1["orderaaa"] + # And with count_desc, 'orderbbb' appears earlier than 'orderaaa' + assert names1.index("orderbbb") < names1.index("orderaaa") + + # 2) name_asc: lexical order should flip the relative order + r2 = http.get( + api_base + "/api/tags", + params={"prefix": "order", "include_zero": "false", "order": "name_asc"}, + timeout=120, + ) + b2 = r2.json() + assert r2.status_code == 200, b2 + names2 = [t["name"] for t in b2["tags"]] + assert "orderaaa" in names2 and "orderbbb" in names2 + assert names2.index("orderaaa") < names2.index("orderbbb") + + # 3) invalid limit rejected (existing negative case retained) + r3 = http.get(api_base + "/api/tags", params={"limit": "1001"}, timeout=120) + b3 = r3.json() + assert r3.status_code == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +def test_tags_endpoints_invalid_bodies(http: requests.Session, api_base: str, seeded_asset: dict): + aid = seeded_asset["id"] + + # Add with empty list + r1 = http.post(f"{api_base}/api/assets/{aid}/tags", json={"tags": []}, timeout=120) + b1 = r1.json() + assert r1.status_code == 400 + assert b1["error"]["code"] == "INVALID_BODY" + + # Remove with wrong type + r2 = http.delete(f"{api_base}/api/assets/{aid}/tags", json={"tags": [123]}, timeout=120) + b2 = r2.json() + assert r2.status_code == 400 + assert b2["error"]["code"] == "INVALID_BODY" + + # metadata_filter provided as JSON array should be rejected (must be object) + r3 = http.get( + api_base + "/api/assets", + params={"metadata_filter": json.dumps([{"x": 1}])}, + timeout=120, + ) + b3 = r3.json() + assert r3.status_code == 400 + assert b3["error"]["code"] == "INVALID_QUERY" + + +def test_tags_prefix_treats_underscore_literal( + http, + api_base, + asset_factory, + make_asset_bytes, +): + """'prefix' for /api/tags must treat '_' literally, not as a wildcard.""" + base = f"pref_{uuid.uuid4().hex[:6]}" + tag_ok = f"{base}_ok" # should match prefix=f"{base}_" + tag_bad = f"{base}xok" # must NOT match if '_' is escaped + scope = f"tags-underscore-{uuid.uuid4().hex[:6]}" + + asset_factory("t1.bin", ["input", "unit-tests", scope, tag_ok], {}, make_asset_bytes("t1", 512)) + asset_factory("t2.bin", ["input", "unit-tests", scope, tag_bad], {}, make_asset_bytes("t2", 512)) + + r = http.get(api_base + "/api/tags", params={"include_zero": "false", "prefix": f"{base}_"}, timeout=120) + body = r.json() + assert r.status_code == 200, body + names = [t["name"] for t in body["tags"]] + assert tag_ok in names, f"Expected {tag_ok} to be returned for prefix '{base}_'" + assert tag_bad not in names, f"'{tag_bad}' must not match — '_' is not a wildcard" + assert body["total"] == 1 diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py new file mode 100644 index 000000000..137d7391a --- /dev/null +++ b/tests-unit/assets_test/test_uploads.py @@ -0,0 +1,281 @@ +import json +import uuid +from concurrent.futures import ThreadPoolExecutor + +import requests +import pytest + + +def test_upload_ok_duplicate_reference(http: requests.Session, api_base: str, make_asset_bytes): + name = "dup_a.safetensors" + tags = ["models", "checkpoints", "unit-tests", "alpha"] + meta = {"purpose": "dup"} + data = make_asset_bytes(name) + files = {"file": (name, data, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a1 = r1.json() + assert r1.status_code == 201, a1 + assert a1["created_new"] is True + + # Second upload with the same data and name should return created_new == False and the same asset + files = {"file": (name, data, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a2 = r2.json() + assert r2.status_code == 200, a2 + assert a2["created_new"] is False + assert a2["asset_hash"] == a1["asset_hash"] + assert a2["id"] == a1["id"] # old reference + + # Third upload with the same data but new name should return created_new == False and the new AssetReference + files = {"file": (name, data, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name + "_d", "user_metadata": json.dumps(meta)} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + a3 = r2.json() + assert r2.status_code == 200, a3 + assert a3["created_new"] is False + assert a3["asset_hash"] == a1["asset_hash"] + assert a3["id"] != a1["id"] # old reference + + +def test_upload_fastpath_from_existing_hash_no_file(http: requests.Session, api_base: str): + # Seed a small file first + name = "fastpath_seed.safetensors" + tags = ["models", "checkpoints", "unit-tests"] + meta = {} + files = {"file": (name, b"B" * 1024, "application/octet-stream")} + form = {"tags": json.dumps(tags), "name": name, "user_metadata": json.dumps(meta)} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + h = b1["asset_hash"] + + # Now POST /api/assets with only hash and no file + files = [ + ("hash", (None, h)), + ("tags", (None, json.dumps(tags))), + ("name", (None, "fastpath_copy.safetensors")), + ("user_metadata", (None, json.dumps({"purpose": "copy"}))), + ] + r2 = http.post(api_base + "/api/assets", files=files, timeout=120) + b2 = r2.json() + assert r2.status_code == 200, b2 # fast path returns 200 with created_new == False + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +def test_upload_fastpath_with_known_hash_and_file( + http: requests.Session, api_base: str +): + # Seed + files = {"file": ("seed.safetensors", b"C" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "seed.safetensors", "user_metadata": json.dumps({})} + r1 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b1 = r1.json() + assert r1.status_code == 201, b1 + h = b1["asset_hash"] + + # Send both file and hash of existing content -> server must drain file and create from hash (200) + files = {"file": ("ignored.bin", b"ignored" * 10, "application/octet-stream")} + form = {"hash": h, "tags": json.dumps(["models", "checkpoints", "unit-tests", "fp"]), "name": "copy_from_hash.safetensors", "user_metadata": json.dumps({})} + r2 = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + b2 = r2.json() + assert r2.status_code == 200, b2 + assert b2["created_new"] is False + assert b2["asset_hash"] == h + + +def test_upload_multiple_tags_fields_are_merged(http: requests.Session, api_base: str): + data = [ + ("tags", "models,checkpoints"), + ("tags", json.dumps(["unit-tests", "alpha"])), + ("name", "merge.safetensors"), + ("user_metadata", json.dumps({"u": 1})), + ] + files = {"file": ("merge.safetensors", b"B" * 256, "application/octet-stream")} + r1 = http.post(api_base + "/api/assets", data=data, files=files, timeout=120) + created = r1.json() + assert r1.status_code in (200, 201), created + aid = created["id"] + + # Verify all tags are present on the resource + rg = http.get(f"{api_base}/api/assets/{aid}", timeout=120) + detail = rg.json() + assert rg.status_code == 200, detail + tags = set(detail["tags"]) + assert {"models", "checkpoints", "unit-tests", "alpha"}.issubset(tags) + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_concurrent_upload_identical_bytes_different_names( + root: str, + http: requests.Session, + api_base: str, + make_asset_bytes, +): + """ + Two concurrent uploads of identical bytes but different names. + Expect a single Asset (same hash), two AssetInfo rows, and exactly one created_new=True. + """ + scope = f"concupload-{uuid.uuid4().hex[:6]}" + name1, name2 = "cu_a.bin", "cu_b.bin" + data = make_asset_bytes("concurrent", 4096) + tags = [root, "unit-tests", scope] + + def _do_upload(args): + url, form_data, files_data = args + with requests.Session() as s: + return s.post(url, data=form_data, files=files_data, timeout=120) + + url = api_base + "/api/assets" + form1 = {"tags": json.dumps(tags), "name": name1, "user_metadata": json.dumps({})} + files1 = {"file": (name1, data, "application/octet-stream")} + form2 = {"tags": json.dumps(tags), "name": name2, "user_metadata": json.dumps({})} + files2 = {"file": (name2, data, "application/octet-stream")} + + with ThreadPoolExecutor(max_workers=2) as executor: + futures = list(executor.map(_do_upload, [(url, form1, files1), (url, form2, files2)])) + r1, r2 = futures + + b1, b2 = r1.json(), r2.json() + assert r1.status_code in (200, 201), b1 + assert r2.status_code in (200, 201), b2 + assert b1["asset_hash"] == b2["asset_hash"] + assert b1["id"] != b2["id"] + + created_flags = sorted([bool(b1.get("created_new")), bool(b2.get("created_new"))]) + assert created_flags == [False, True] + + rl = http.get( + api_base + "/api/assets", + params={"include_tags": f"unit-tests,{scope}", "sort": "name"}, + timeout=120, + ) + bl = rl.json() + assert rl.status_code == 200, bl + names = [a["name"] for a in bl.get("assets", [])] + assert set([name1, name2]).issubset(names) + + +def test_create_from_hash_endpoint_404(http: requests.Session, api_base: str): + payload = { + "hash": "blake3:" + "0" * 64, + "name": "nonexistent.bin", + "tags": ["models", "checkpoints", "unit-tests"], + } + r = http.post(api_base + "/api/assets/from-hash", json=payload, timeout=120) + body = r.json() + assert r.status_code == 404 + assert body["error"]["code"] == "ASSET_NOT_FOUND" + + +def test_upload_zero_byte_rejected(http: requests.Session, api_base: str): + files = {"file": ("empty.safetensors", b"", "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "empty.safetensors", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "EMPTY_UPLOAD" + + +def test_upload_invalid_root_tag_rejected(http: requests.Session, api_base: str): + files = {"file": ("badroot.bin", b"A" * 64, "application/octet-stream")} + form = {"tags": json.dumps(["not-a-root", "whatever"]), "name": "badroot.bin", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_user_metadata_must_be_json(http: requests.Session, api_base: str): + files = {"file": ("badmeta.bin", b"A" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "edge"]), "name": "badmeta.bin", "user_metadata": "{not json}"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_requires_multipart(http: requests.Session, api_base: str): + r = http.post(api_base + "/api/assets", json={"foo": "bar"}, timeout=120) + body = r.json() + assert r.status_code == 415 + assert body["error"]["code"] == "UNSUPPORTED_MEDIA_TYPE" + + +def test_upload_missing_file_and_hash(http: requests.Session, api_base: str): + files = [ + ("tags", (None, json.dumps(["models", "checkpoints", "unit-tests"]))), + ("name", (None, "x.safetensors")), + ] + r = http.post(api_base + "/api/assets", files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "MISSING_FILE" + + +def test_upload_models_unknown_category(http: requests.Session, api_base: str): + files = {"file": ("m.safetensors", b"A" * 128, "application/octet-stream")} + form = {"tags": json.dumps(["models", "no_such_category", "unit-tests"]), "name": "m.safetensors"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + assert body["error"]["message"].startswith("unknown models category") + + +def test_upload_models_requires_category(http: requests.Session, api_base: str): + files = {"file": ("nocat.safetensors", b"A" * 64, "application/octet-stream")} + form = {"tags": json.dumps(["models"]), "name": "nocat.safetensors", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + +def test_upload_tags_traversal_guard(http: requests.Session, api_base: str): + files = {"file": ("evil.safetensors", b"A" * 256, "application/octet-stream")} + form = {"tags": json.dumps(["models", "checkpoints", "unit-tests", "..", "zzz"]), "name": "evil.safetensors"} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") + + +@pytest.mark.parametrize("root", ["input", "output"]) +def test_duplicate_upload_same_display_name_does_not_clobber( + root: str, + http: requests.Session, + api_base: str, + asset_factory, + make_asset_bytes, +): + """ + Two uploads use the same tags and the same display name but different bytes. + With hash-based filenames, they must NOT overwrite each other. Both assets + remain accessible and serve their original content. + """ + scope = f"dup-path-{uuid.uuid4().hex[:6]}" + display_name = "same_display.bin" + + d1 = make_asset_bytes(scope + "-v1", 1536) + d2 = make_asset_bytes(scope + "-v2", 2048) + tags = [root, "unit-tests", scope] + + first = asset_factory(display_name, tags, {}, d1) + second = asset_factory(display_name, tags, {}, d2) + + assert first["id"] != second["id"] + assert first["asset_hash"] != second["asset_hash"] # different content + assert first["name"] == second["name"] == display_name + + # Both must be independently retrievable + r1 = http.get(f"{api_base}/api/assets/{first['id']}/content", timeout=120) + b1 = r1.content + assert r1.status_code == 200 + assert b1 == d1 + r2 = http.get(f"{api_base}/api/assets/{second['id']}/content", timeout=120) + b2 = r2.content + assert r2.status_code == 200 + assert b2 == d2 diff --git a/tests-unit/requirements.txt b/tests-unit/requirements.txt index 3a6790ee0..2355b8000 100644 --- a/tests-unit/requirements.txt +++ b/tests-unit/requirements.txt @@ -2,3 +2,4 @@ pytest>=7.8.0 pytest-aiohttp pytest-asyncio websocket-client +blake3