From b23302f372e3e16a00004a9dda623a113fc5ab90 Mon Sep 17 00:00:00 2001 From: Luke Mino-Altherr Date: Wed, 4 Feb 2026 15:04:30 -0800 Subject: [PATCH] refactor(assets): consolidate duplicated query utilities and remove unused code - Extract shared helpers to database/queries/common.py: - MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks, iter_row_chunks - build_visible_owner_clause - Remove duplicate _compute_filename_for_asset, consolidate in path_utils.py - Remove unused get_asset_info_with_tags (duplicated get_asset_detail) - Remove redundant __all__ from cache_state.py - Make internal helpers private (_check_is_scalar) Amp-Thread-ID: https://ampcode.com/threads/T-019c2ad9-9432-7451-94a8-79287dbbb19e Co-authored-by: Amp --- app/assets/api/routes.py | 176 ++++++++++++++++----- app/assets/api/schemas_in.py | 10 +- app/assets/api/upload.py | 32 +++- app/assets/database/models.py | 48 ++++-- app/assets/database/queries/asset.py | 33 ++-- app/assets/database/queries/asset_info.py | 89 +++++------ app/assets/database/queries/cache_state.py | 53 +++---- app/assets/database/queries/common.py | 37 +++++ app/assets/database/queries/tags.py | 103 ++++++------ app/assets/helpers.py | 13 +- app/assets/scanner.py | 116 ++++++++------ app/assets/services/__init__.py | 2 - app/assets/services/asset_management.py | 46 ++---- app/assets/services/ingest.py | 29 ++-- app/assets/services/path_utils.py | 36 ++++- 15 files changed, 514 insertions(+), 309 deletions(-) create mode 100644 app/assets/database/queries/common.py diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 4517c3ef5..23fbe822a 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -38,6 +38,7 @@ USER_MANAGER: user_manager.UserManager | None = None # UUID regex (canonical hyphenated form, case-insensitive) UUID_RE = r"[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F]{12}" + def get_query_dict(request: web.Request) -> dict[str, Any]: """ Gets a dictionary of query parameters from the request. @@ -45,21 +46,33 @@ def get_query_dict(request: web.Request) -> dict[str, Any]: 'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. """ query_dict = { - key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) + key: request.query.getall(key) + if len(request.query.getall(key)) > 1 + else request.query.get(key) for key in request.query.keys() } return query_dict + # Note to any custom node developers reading this code: # The assets system is not yet fully implemented, do not rely on the code in /app/assets remaining the same. 
-def register_assets_system(app: web.Application, user_manager_instance: user_manager.UserManager) -> None: + +def register_assets_system( + app: web.Application, user_manager_instance: user_manager.UserManager +) -> None: global USER_MANAGER USER_MANAGER = user_manager_instance app.add_routes(ROUTES) -def _build_error_response(status: int, code: str, message: str, details: dict | None = None) -> web.Response: - return web.json_response({"error": {"code": code, "message": message, "details": details or {}}}, status=status) + +def _build_error_response( + status: int, code: str, message: str, details: dict | None = None +) -> web.Response: + return web.json_response( + {"error": {"code": code, "message": message, "details": details or {}}}, + status=status, + ) def _build_validation_error_response(code: str, ve: ValidationError) -> web.Response: @@ -79,10 +92,18 @@ def _validate_sort_field(requested: str | None) -> str: async def head_asset_by_hash(request: web.Request) -> web.Response: hash_str = request.match_info.get("hash", "").strip().lower() if not hash_str or ":" not in hash_str: - return _build_error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + return _build_error_response( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) algo, digest = hash_str.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): - return _build_error_response(400, "INVALID_HASH", "hash must be like 'blake3:'") + if ( + algo != "blake3" + or not digest + or any(c for c in digest if c not in "0123456789abcdef") + ): + return _build_error_response( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) exists = asset_exists(hash_str) return web.Response(status=200 if exists else 404) @@ -99,7 +120,11 @@ async def list_assets_route(request: web.Request) -> web.Response: return _build_validation_error_response("INVALID_QUERY", ve) sort = _validate_sort_field(q.sort) - order = "desc" if (q.order or "desc").lower() not in {"asc", "desc"} else q.order.lower() + order = ( + "desc" + if (q.order or "desc").lower() not in {"asc", "desc"} + else q.order.lower() + ) result = list_assets_page( owner_id=USER_MANAGER.get_request_user_id(request), @@ -118,7 +143,9 @@ async def list_assets_route(request: web.Request) -> web.Response: id=item.info.id, name=item.info.name, asset_hash=item.asset.hash if item.asset else None, - size=int(item.asset.size_bytes) if item.asset and item.asset.size_bytes else None, + size=int(item.asset.size_bytes) + if item.asset and item.asset.size_bytes + else None, mime_type=item.asset.mime_type if item.asset else None, tags=item.tags, created_at=item.info.created_at, @@ -148,13 +175,20 @@ async def get_asset_route(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) if not result: - return _build_error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found", {"id": asset_info_id}) + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + f"AssetInfo {asset_info_id} not found", + {"id": asset_info_id}, + ) payload = schemas_out.AssetDetail( id=result.info.id, name=result.info.name, asset_hash=result.asset.hash if result.asset else None, - size=int(result.asset.size_bytes) if result.asset and result.asset.size_bytes is not None else None, + size=int(result.asset.size_bytes) + if result.asset and result.asset.size_bytes is not None + else None, mime_type=result.asset.mime_type if result.asset else None, tags=result.tags, user_metadata=result.info.user_metadata or 
{}, @@ -163,7 +197,9 @@ async def get_asset_route(request: web.Request) -> web.Response: last_access_time=result.info.last_access_time, ) except ValueError as e: - return _build_error_response(404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(e), {"id": asset_info_id} + ) except Exception: logging.exception( "get_asset failed for asset_info_id=%s, owner_id=%s", @@ -193,10 +229,12 @@ async def download_asset_content(request: web.Request) -> web.Response: except NotImplementedError as nie: return _build_error_response(501, "BACKEND_UNSUPPORTED", str(nie)) except FileNotFoundError: - return _build_error_response(404, "FILE_NOT_FOUND", "Underlying file not found on disk.") + return _build_error_response( + 404, "FILE_NOT_FOUND", "Underlying file not found on disk." + ) quoted = (filename or "").replace("\r", "").replace("\n", "").replace('"', "'") - cd = f'{disposition}; filename="{quoted}"; filename*=UTF-8\'\'{urllib.parse.quote(filename)}' + cd = f"{disposition}; filename=\"{quoted}\"; filename*=UTF-8''{urllib.parse.quote(filename)}" file_size = os.path.getsize(abs_path) logging.info( @@ -235,7 +273,9 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response: except ValidationError as ve: return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _build_error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) result = create_from_hash( hash_str=body.hash, @@ -245,7 +285,9 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) if result is None: - return _build_error_response(404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" + ) payload_out = schemas_out.AssetCreated( id=result.info.id, @@ -282,21 +324,30 @@ async def upload_asset(request: web.Request) -> web.Response: owner_id = USER_MANAGER.get_request_user_id(request) try: - spec = schemas_in.UploadAssetSpec.model_validate({ - "tags": parsed.tags_raw, - "name": parsed.provided_name, - "user_metadata": parsed.user_metadata_raw, - "hash": parsed.provided_hash, - }) + spec = schemas_in.UploadAssetSpec.model_validate( + { + "tags": parsed.tags_raw, + "name": parsed.provided_name, + "user_metadata": parsed.user_metadata_raw, + "hash": parsed.provided_hash, + } + ) except ValidationError as ve: _delete_temp_file_if_exists(parsed.tmp_path) - return _build_error_response(400, "INVALID_BODY", f"Validation failed: {ve.json()}") + return _build_error_response( + 400, "INVALID_BODY", f"Validation failed: {ve.json()}" + ) if spec.tags and spec.tags[0] == "models": - if len(spec.tags) < 2 or spec.tags[1] not in folder_paths.folder_names_and_paths: + if ( + len(spec.tags) < 2 + or spec.tags[1] not in folder_paths.folder_names_and_paths + ): _delete_temp_file_if_exists(parsed.tmp_path) category = spec.tags[1] if len(spec.tags) >= 2 else "" - return _build_error_response(400, "INVALID_BODY", f"unknown models category '{category}'") + return _build_error_response( + 400, "INVALID_BODY", f"unknown models category '{category}'" + ) try: # Fast path: if a valid provided hash exists, create AssetInfo without writing anything @@ -310,12 +361,18 @@ async def upload_asset(request: web.Request) -> web.Response: ) if result is None: 
_delete_temp_file_if_exists(parsed.tmp_path) - return _build_error_response(404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"Asset content {spec.hash} does not exist" + ) _delete_temp_file_if_exists(parsed.tmp_path) else: # Otherwise, we must have a temp file path to ingest if not parsed.tmp_path or not os.path.exists(parsed.tmp_path): - return _build_error_response(404, "ASSET_NOT_FOUND", "Provided hash not found and no file uploaded.") + return _build_error_response( + 404, + "ASSET_NOT_FOUND", + "Provided hash not found and no file uploaded.", + ) result = upload_from_temp_path( temp_path=parsed.tmp_path, @@ -365,7 +422,9 @@ async def update_asset_route(request: web.Request) -> web.Response: except ValidationError as ve: return _build_validation_error_response("INVALID_BODY", ve) except Exception: - return _build_error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: result = update_asset_metadata( @@ -383,7 +442,9 @@ async def update_asset_route(request: web.Request) -> web.Response: updated_at=result.info.updated_at, ) except (ValueError, PermissionError) as ve: - return _build_error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id} + ) except Exception: logging.exception( "update_asset failed for asset_info_id=%s, owner_id=%s", @@ -398,7 +459,11 @@ async def update_asset_route(request: web.Request) -> web.Response: async def delete_asset_route(request: web.Request) -> web.Response: asset_info_id = str(uuid.UUID(request.match_info["id"])) delete_content_param = request.query.get("delete_content") - delete_content = True if delete_content_param is None else delete_content_param.lower() not in {"0", "false", "no"} + delete_content = ( + True + if delete_content_param is None + else delete_content_param.lower() not in {"0", "false", "no"} + ) try: deleted = delete_asset_reference( @@ -415,7 +480,9 @@ async def delete_asset_route(request: web.Request) -> web.Response: return _build_error_response(500, "INTERNAL", "Unexpected server error.") if not deleted: - return _build_error_response(404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found.") + return _build_error_response( + 404, "ASSET_NOT_FOUND", f"AssetInfo {asset_info_id} not found." 
+ ) return web.Response(status=204) @@ -430,7 +497,13 @@ async def get_tags(request: web.Request) -> web.Response: query = schemas_in.TagsListQuery.model_validate(query_map) except ValidationError as e: return web.json_response( - {"error": {"code": "INVALID_QUERY", "message": "Invalid query parameters", "details": e.errors()}}, + { + "error": { + "code": "INVALID_QUERY", + "message": "Invalid query parameters", + "details": e.errors(), + } + }, status=400, ) @@ -443,8 +516,13 @@ async def get_tags(request: web.Request) -> web.Response: owner_id=USER_MANAGER.get_request_user_id(request), ) - tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows] - payload = schemas_out.TagsList(tags=tags, total=total, has_more=(query.offset + len(tags)) < total) + tags = [ + schemas_out.TagUsage(name=name, count=count, type=tag_type) + for (name, tag_type, count) in rows + ] + payload = schemas_out.TagsList( + tags=tags, total=total, has_more=(query.offset + len(tags)) < total + ) return web.json_response(payload.model_dump(mode="json")) @@ -455,9 +533,16 @@ async def add_asset_tags(request: web.Request) -> web.Response: json_payload = await request.json() data = schemas_in.TagsAdd.model_validate(json_payload) except ValidationError as ve: - return _build_error_response(400, "INVALID_BODY", "Invalid JSON body for tags add.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags add.", + {"errors": ve.errors()}, + ) except Exception: - return _build_error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." + ) try: result = apply_tags( @@ -472,7 +557,9 @@ async def add_asset_tags(request: web.Request) -> web.Response: total_tags=result.total_tags, ) except (ValueError, PermissionError) as ve: - return _build_error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id} + ) except Exception: logging.exception( "add_tags_to_asset failed for asset_info_id=%s, owner_id=%s", @@ -491,9 +578,16 @@ async def delete_asset_tags(request: web.Request) -> web.Response: json_payload = await request.json() data = schemas_in.TagsRemove.model_validate(json_payload) except ValidationError as ve: - return _build_error_response(400, "INVALID_BODY", "Invalid JSON body for tags remove.", {"errors": ve.errors()}) + return _build_error_response( + 400, + "INVALID_BODY", + "Invalid JSON body for tags remove.", + {"errors": ve.errors()}, + ) except Exception: - return _build_error_response(400, "INVALID_JSON", "Request body must be valid JSON.") + return _build_error_response( + 400, "INVALID_JSON", "Request body must be valid JSON." 
+ ) try: result = remove_tags( @@ -507,7 +601,9 @@ async def delete_asset_tags(request: web.Request) -> web.Response: total_tags=result.total_tags, ) except ValueError as ve: - return _build_error_response(404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id}) + return _build_error_response( + 404, "ASSET_NOT_FOUND", str(ve), {"id": asset_info_id} + ) except Exception: logging.exception( "remove_tags_from_asset failed for asset_info_id=%s, owner_id=%s", diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index 4c126853e..081918757 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -65,6 +65,7 @@ class ParsedUpload: provided_hash: str | None provided_hash_exists: bool | None + class ListAssetsQuery(BaseModel): include_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list) @@ -76,7 +77,9 @@ class ListAssetsQuery(BaseModel): limit: conint(ge=1, le=500) = 20 offset: conint(ge=0) = 0 - sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" + sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = ( + "created_at" + ) order: Literal["asc", "desc"] = "desc" @field_validator("include_tags", "exclude_tags", mode="before") @@ -218,6 +221,7 @@ class UploadAssetSpec(BaseModel): Files created via this endpoint are stored on disk using the **content hash** as the filename stem and the original extension is preserved when available. """ + model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) tags: list[str] = Field(..., min_length=1) @@ -315,5 +319,7 @@ class UploadAssetSpec(BaseModel): raise ValueError("first tag must be one of: models, input, output") if root == "models": if len(self.tags) < 2: - raise ValueError("models uploads require a category tag as the second tag") + raise ValueError( + "models uploads require a category tag as the second tag" + ) return self diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py index 6dfe4f35a..f5120d07c 100644 --- a/app/assets/api/upload.py +++ b/app/assets/api/upload.py @@ -19,7 +19,11 @@ def normalize_and_validate_hash(s: str) -> str: if ":" not in s: raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") algo, digest = s.split(":", 1) - if algo != "blake3" or not digest or any(c for c in digest if c not in "0123456789abcdef"): + if ( + algo != "blake3" + or not digest + or any(c for c in digest if c not in "0123456789abcdef") + ): raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") return f"{algo}:{digest}" @@ -42,7 +46,9 @@ async def parse_multipart_upload( UploadError: On validation or I/O errors """ if not (request.content_type or "").lower().startswith("multipart/"): - raise UploadError(415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads.") + raise UploadError( + 415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads." 
+ ) reader = await request.multipart() @@ -68,7 +74,9 @@ async def parse_multipart_upload( try: s = ((await field.text()) or "").strip().lower() except Exception: - raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:'") + raise UploadError( + 400, "INVALID_HASH", "hash must be like 'blake3:'" + ) if s: provided_hash = normalize_and_validate_hash(s) @@ -90,7 +98,9 @@ async def parse_multipart_upload( break file_written += len(chunk) except Exception: - raise UploadError(500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file.") + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file." + ) continue uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads") @@ -108,7 +118,9 @@ async def parse_multipart_upload( file_written += len(chunk) except Exception: _delete_temp_file_if_exists(tmp_path) - raise UploadError(500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file.") + raise UploadError( + 500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file." + ) elif fname == "tags": tags_raw.append((await field.text()) or "") @@ -118,9 +130,15 @@ async def parse_multipart_upload( user_metadata_raw = (await field.text()) or None if not file_present and not (provided_hash and provided_hash_exists): - raise UploadError(400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'.") + raise UploadError( + 400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'." + ) - if file_present and file_written == 0 and not (provided_hash and provided_hash_exists): + if ( + file_present + and file_written == 0 + and not (provided_hash and provided_hash_exists) + ): _delete_temp_file_if_exists(tmp_path) raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.") diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 20ac81e27..67914d9df 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -27,7 +27,9 @@ from app.database.models import Base, to_dict class Asset(Base): __tablename__ = "assets" - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) hash: Mapped[str | None] = mapped_column(String(256), nullable=True) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) mime_type: Mapped[str | None] = mapped_column(String(255)) @@ -75,7 +77,9 @@ class AssetCacheState(Base): __tablename__ = "asset_cache_state" id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) + asset_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False + ) file_path: Mapped[str] = mapped_column(Text, nullable=False) mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False) @@ -85,7 +89,9 @@ class AssetCacheState(Base): __table_args__ = ( Index("ix_asset_cache_state_file_path", "file_path"), Index("ix_asset_cache_state_asset_id", "asset_id"), - CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), + CheckConstraint( + "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg" + ), UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), ) @@ -99,15 +105,29 @@ class 
AssetCacheState(Base): class AssetInfo(Base): __tablename__ = "assets_info" - id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) + id: Mapped[str] = mapped_column( + String(36), primary_key=True, default=lambda: str(uuid.uuid4()) + ) owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) - asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) - preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) - user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) - created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=get_utc_now) - updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=get_utc_now) - last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=get_utc_now) + asset_id: Mapped[str] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False + ) + preview_id: Mapped[str | None] = mapped_column( + String(36), ForeignKey("assets.id", ondelete="SET NULL") + ) + user_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True) + ) + created_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + updated_at: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) + last_access_time: Mapped[datetime] = mapped_column( + DateTime(timezone=False), nullable=False, default=get_utc_now + ) asset: Mapped[Asset] = relationship( "Asset", @@ -143,7 +163,9 @@ class AssetInfo(Base): ) __table_args__ = ( - UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), + UniqueConstraint( + "asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name" + ), Index("ix_assets_info_owner_name", "owner_id", "name"), Index("ix_assets_info_owner_id", "owner_id"), Index("ix_assets_info_asset_id", "asset_id"), @@ -225,9 +247,7 @@ class Tag(Base): overlaps="asset_info_links,tag_links,tags,asset_info", ) - __table_args__ = ( - Index("ix_tags_tag_type", "tag_type"), - ) + __table_args__ = (Index("ix_tags_tag_type", "tag_type"),) def __repr__(self) -> str: return f"" diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py index 174b30837..6913fb501 100644 --- a/app/assets/database/queries/asset.py +++ b/app/assets/database/queries/asset.py @@ -4,17 +4,7 @@ from sqlalchemy.dialects import sqlite from sqlalchemy.orm import Session from app.assets.database.models import Asset - -MAX_BIND_PARAMS = 800 - - -def _calculate_rows_per_statement(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i : i + n] +from app.assets.database.queries.common import calculate_rows_per_statement, iter_chunks def asset_exists_by_hash( @@ -26,7 +16,10 @@ def asset_exists_by_hash( """ row = ( session.execute( - select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1) + select(sa.literal(True)) + .select_from(Asset) + .where(Asset.hash == asset_hash) + .limit(1) ) ).first() return row is not None @@ -37,8 +30,10 @@ def get_asset_by_hash( asset_hash: str, ) -> Asset | None: return ( - 
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) - ).scalars().first() + (session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))) + .scalars() + .first() + ) def upsert_asset( @@ -60,9 +55,11 @@ def upsert_asset( res = session.execute(ins) created = int(res.rowcount or 0) > 0 - asset = session.execute( - select(Asset).where(Asset.hash == asset_hash).limit(1) - ).scalars().first() + asset = ( + session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)) + .scalars() + .first() + ) if not asset: raise RuntimeError("Asset row not found after upsert.") @@ -89,5 +86,5 @@ def bulk_insert_assets( if not rows: return ins = sqlite.insert(Asset) - for chunk in _iter_chunks(rows, _calculate_rows_per_statement(5)): + for chunk in iter_chunks(rows, calculate_rows_per_statement(5)): session.execute(ins, chunk) diff --git a/app/assets/database/queries/asset_info.py b/app/assets/database/queries/asset_info.py index fce2a71ef..23716a929 100644 --- a/app/assets/database/queries/asset_info.py +++ b/app/assets/database/queries/asset_info.py @@ -16,10 +16,16 @@ from app.assets.database.models import ( AssetInfoTag, Tag, ) +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + build_visible_owner_clause, + calculate_rows_per_statement, + iter_chunks, +) from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags -def check_is_scalar(v): +def _check_is_scalar(v): if v is None: return True if isinstance(v, bool): @@ -33,8 +39,12 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict: """Convert a scalar value to a typed projection row.""" if value is None: return { - "key": key, "ordinal": ordinal, - "val_str": None, "val_num": None, "val_bool": None, "val_json": None + "key": key, + "ordinal": ordinal, + "val_str": None, + "val_num": None, + "val_bool": None, + "val_json": None, } if isinstance(value, bool): return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} @@ -55,35 +65,16 @@ def convert_metadata_to_rows(key: str, value) -> list[dict]: if value is None: return [_scalar_to_row(key, 0, None)] - if check_is_scalar(value): + if _check_is_scalar(value): return [_scalar_to_row(key, 0, value)] if isinstance(value, list): - if all(check_is_scalar(x) for x in value): + if all(_check_is_scalar(x) for x in value): return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] return [{"key": key, "ordinal": 0, "val_json": value}] -MAX_BIND_PARAMS = 800 - - -def _calculate_rows_per_statement(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i : i + n] - - -def _build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads. 
Owner-less rows are visible to everyone.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - def _apply_tag_filters( stmt: sa.sql.Select, @@ -229,15 +220,19 @@ def get_or_create_asset_info( if info: return info, True - existing = session.execute( - select(AssetInfo) - .where( - AssetInfo.asset_id == asset_id, - AssetInfo.name == name, - AssetInfo.owner_id == owner_id, + existing = ( + session.execute( + select(AssetInfo) + .where( + AssetInfo.asset_id == asset_id, + AssetInfo.name == name, + AssetInfo.owner_id == owner_id, + ) + .limit(1) ) - .limit(1) - ).unique().scalar_one_or_none() + .unique() + .scalar_one_or_none() + ) if not existing: raise RuntimeError("Failed to find AssetInfo after insert conflict.") return existing, False @@ -274,7 +269,7 @@ def list_asset_infos_page( select(AssetInfo) .join(Asset, Asset.id == AssetInfo.asset_id) .options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags)) - .where(_build_visible_owner_clause(owner_id)) + .where(build_visible_owner_clause(owner_id)) ) if name_contains: @@ -302,7 +297,7 @@ def list_asset_infos_page( select(sa.func.count()) .select_from(AssetInfo) .join(Asset, Asset.id == AssetInfo.asset_id) - .where(_build_visible_owner_clause(owner_id)) + .where(build_visible_owner_clause(owner_id)) ) if name_contains: escaped, esc = escape_sql_like_string(name_contains) @@ -341,7 +336,7 @@ def fetch_asset_info_asset_and_tags( .join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True) .where( AssetInfo.id == asset_info_id, - _build_visible_owner_clause(owner_id), + build_visible_owner_clause(owner_id), ) .options(noload(AssetInfo.tags)) .order_by(Tag.name.asc()) @@ -371,7 +366,7 @@ def fetch_asset_info_and_asset( .join(Asset, Asset.id == AssetInfo.asset_id) .where( AssetInfo.id == asset_info_id, - _build_visible_owner_clause(owner_id), + build_visible_owner_clause(owner_id), ) .limit(1) .options(noload(AssetInfo.tags)) @@ -393,7 +388,9 @@ def update_asset_info_access_time( stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id) if only_if_newer: stmt = stmt.where( - sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts) + sa.or_( + AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts + ) ) session.execute(stmt.values(last_access_time=ts)) @@ -420,9 +417,7 @@ def update_asset_info_updated_at( """Update the updated_at timestamp of an AssetInfo.""" ts = ts or get_utc_now() session.execute( - sa.update(AssetInfo) - .where(AssetInfo.id == asset_info_id) - .values(updated_at=ts) + sa.update(AssetInfo).where(AssetInfo.id == asset_info_id).values(updated_at=ts) ) @@ -439,7 +434,9 @@ def set_asset_info_metadata( info.updated_at = get_utc_now() session.flush() - session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id)) + session.execute( + delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id) + ) session.flush() if not user_metadata: @@ -471,7 +468,7 @@ def delete_asset_info_by_id( ) -> bool: stmt = sa.delete(AssetInfo).where( AssetInfo.id == asset_info_id, - _build_visible_owner_clause(owner_id), + build_visible_owner_clause(owner_id), ) return int((session.execute(stmt)).rowcount or 0) > 0 @@ -511,7 +508,7 @@ def bulk_insert_asset_infos_ignore_conflicts( ins = sqlite.insert(AssetInfo).on_conflict_do_nothing( index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name] ) - for chunk in _iter_chunks(rows, _calculate_rows_per_statement(9)): + for 
chunk in iter_chunks(rows, calculate_rows_per_statement(9)): session.execute(ins, chunk) @@ -524,9 +521,7 @@ def get_asset_info_ids_by_ids( return set() found: set[str] = set() - for chunk in _iter_chunks(info_ids, MAX_BIND_PARAMS): - result = session.execute( - select(AssetInfo.id).where(AssetInfo.id.in_(chunk)) - ) + for chunk in iter_chunks(info_ids, MAX_BIND_PARAMS): + result = session.execute(select(AssetInfo.id).where(AssetInfo.id.in_(chunk))) found.update(result.scalars().all()) return found diff --git a/app/assets/database/queries/cache_state.py b/app/assets/database/queries/cache_state.py index 4c618751a..5a304a641 100644 --- a/app/assets/database/queries/cache_state.py +++ b/app/assets/database/queries/cache_state.py @@ -7,34 +7,13 @@ from sqlalchemy.dialects import sqlite from sqlalchemy.orm import Session from app.assets.database.models import Asset, AssetCacheState, AssetInfo +from app.assets.database.queries.common import ( + MAX_BIND_PARAMS, + calculate_rows_per_statement, + iter_chunks, +) from app.assets.helpers import escape_sql_like_string -MAX_BIND_PARAMS = 800 - -__all__ = [ - "CacheStateRow", - "list_cache_states_by_asset_id", - "upsert_cache_state", - "delete_cache_states_outside_prefixes", - "get_orphaned_seed_asset_ids", - "delete_assets_by_ids", - "get_cache_states_for_prefixes", - "bulk_set_needs_verify", - "delete_cache_states_by_ids", - "delete_orphaned_seed_asset", - "bulk_insert_cache_states_ignore_conflicts", - "get_cache_states_by_paths_and_asset_ids", -] - - -def _calculate_rows_per_statement(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def _iter_chunks(seq, n: int): - for i in range(0, len(seq), n): - yield seq[i : i + n] - class CacheStateRow(NamedTuple): """Row from cache state query with joined asset data.""" @@ -52,12 +31,16 @@ def list_cache_states_by_asset_id( session: Session, *, asset_id: str ) -> Sequence[AssetCacheState]: return ( - session.execute( - select(AssetCacheState) - .where(AssetCacheState.asset_id == asset_id) - .order_by(AssetCacheState.id.asc()) + ( + session.execute( + select(AssetCacheState) + .where(AssetCacheState.asset_id == asset_id) + .order_by(AssetCacheState.id.asc()) + ) ) - ).scalars().all() + .scalars() + .all() + ) def upsert_cache_state( @@ -100,7 +83,9 @@ def upsert_cache_state( return False, updated -def delete_cache_states_outside_prefixes(session: Session, valid_prefixes: list[str]) -> int: +def delete_cache_states_outside_prefixes( + session: Session, valid_prefixes: list[str] +) -> int: """Delete cache states with file_path not matching any of the valid prefixes. 
Args: @@ -261,7 +246,7 @@ def bulk_insert_cache_states_ignore_conflicts( ins = sqlite.insert(AssetCacheState).on_conflict_do_nothing( index_elements=[AssetCacheState.file_path] ) - for chunk in _iter_chunks(rows, _calculate_rows_per_statement(3)): + for chunk in iter_chunks(rows, calculate_rows_per_statement(3)): session.execute(ins, chunk) @@ -283,7 +268,7 @@ def get_cache_states_by_paths_and_asset_ids( paths = list(path_to_asset.keys()) winners: set[str] = set() - for chunk in _iter_chunks(paths, MAX_BIND_PARAMS): + for chunk in iter_chunks(paths, MAX_BIND_PARAMS): result = session.execute( select(AssetCacheState.file_path).where( AssetCacheState.file_path.in_(chunk), diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py new file mode 100644 index 000000000..4086cea56 --- /dev/null +++ b/app/assets/database/queries/common.py @@ -0,0 +1,37 @@ +"""Shared utilities for database query modules.""" + +from typing import Iterable + +import sqlalchemy as sa + +from app.assets.database.models import AssetInfo + +MAX_BIND_PARAMS = 800 + + +def calculate_rows_per_statement(cols: int) -> int: + """Calculate how many rows can fit in one statement given column count.""" + return max(1, MAX_BIND_PARAMS // max(1, cols)) + + +def iter_chunks(seq, n: int): + """Yield successive n-sized chunks from seq.""" + for i in range(0, len(seq), n): + yield seq[i : i + n] + + +def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: + """Yield chunks of rows sized to fit within bind param limits.""" + if not rows: + return [] + rows_per_stmt = max(1, MAX_BIND_PARAMS // max(1, cols_per_row)) + for i in range(0, len(rows), rows_per_stmt): + yield rows[i : i + rows_per_stmt] + + +def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: + """Build owner visibility predicate for reads. Owner-less rows are visible to everyone.""" + owner_id = (owner_id or "").strip() + if owner_id == "": + return AssetInfo.owner_id == "" + return AssetInfo.owner_id.in_(["", owner_id]) diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 7733d6e2b..53548b383 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -7,6 +7,10 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from app.assets.database.models import AssetInfo, AssetInfoMeta, AssetInfoTag, Tag +from app.assets.database.queries.common import ( + build_visible_owner_clause, + iter_row_chunks, +) from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags @@ -27,30 +31,10 @@ class SetTagsDict(TypedDict): removed: list[str] total: list[str] -MAX_BIND_PARAMS = 800 - -def _calculate_rows_per_statement(cols: int) -> int: - return max(1, MAX_BIND_PARAMS // max(1, cols)) - - -def _iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]: - if not rows: - return [] - rows_per_stmt = max(1, MAX_BIND_PARAMS // max(1, cols_per_row)) - for i in range(0, len(rows), rows_per_stmt): - yield rows[i : i + rows_per_stmt] - - -def _build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement: - """Build owner visibility predicate for reads. 
Owner-less rows are visible to everyone.""" - owner_id = (owner_id or "").strip() - if owner_id == "": - return AssetInfo.owner_id == "" - return AssetInfo.owner_id.in_(["", owner_id]) - - -def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None: +def ensure_tags_exist( + session: Session, names: Iterable[str], tag_type: str = "user" +) -> None: wanted = normalize_tags(list(names)) if not wanted: return @@ -65,9 +49,12 @@ def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "u def get_asset_tags(session: Session, asset_info_id: str) -> list[str]: return [ - tag_name for (tag_name,) in ( + tag_name + for (tag_name,) in ( session.execute( - select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + select(AssetInfoTag.tag_name).where( + AssetInfoTag.asset_info_id == asset_info_id + ) ) ).all() ] @@ -82,8 +69,13 @@ def set_asset_info_tags( desired = normalize_tags(tags) current = set( - tag_name for (tag_name,) in ( - session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)) + tag_name + for (tag_name,) in ( + session.execute( + select(AssetInfoTag.tag_name).where( + AssetInfoTag.asset_info_id == asset_info_id + ) + ) ).all() ) @@ -92,16 +84,25 @@ def set_asset_info_tags( if to_add: ensure_tags_exist(session, to_add, tag_type="user") - session.add_all([ - AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=get_utc_now()) - for t in to_add - ]) + session.add_all( + [ + AssetInfoTag( + asset_info_id=asset_info_id, + tag_name=t, + origin=origin, + added_at=get_utc_now(), + ) + for t in to_add + ] + ) session.flush() if to_remove: session.execute( - delete(AssetInfoTag) - .where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove)) + delete(AssetInfoTag).where( + AssetInfoTag.asset_info_id == asset_info_id, + AssetInfoTag.tag_name.in_(to_remove), + ) ) session.flush() @@ -133,7 +134,9 @@ def add_tags_to_asset_info( tag_name for (tag_name,) in ( session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + sa.select(AssetInfoTag.tag_name).where( + AssetInfoTag.asset_info_id == asset_info_id + ) ) ).all() } @@ -185,7 +188,9 @@ def remove_tags_from_asset_info( tag_name for (tag_name,) in ( session.execute( - sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id) + sa.select(AssetInfoTag.tag_name).where( + AssetInfoTag.asset_info_id == asset_info_id + ) ) ).all() } @@ -195,8 +200,7 @@ def remove_tags_from_asset_info( if to_remove: session.execute( - delete(AssetInfoTag) - .where( + delete(AssetInfoTag).where( AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove), ) @@ -222,7 +226,10 @@ def add_missing_tag_for_asset_id( .where(AssetInfo.asset_id == asset_id) .where( sa.not_( - sa.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing")) + sa.exists().where( + (AssetInfoTag.asset_info_id == AssetInfo.id) + & (AssetInfoTag.tag_name == "missing") + ) ) ) ) @@ -232,7 +239,9 @@ def add_missing_tag_for_asset_id( ["asset_info_id", "tag_name", "origin", "added_at"], select_rows, ) - .on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name]) + .on_conflict_do_nothing( + index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name] + ) ) @@ -242,7 +251,9 @@ def remove_missing_tag_for_asset_id( ) -> None: session.execute( sa.delete(AssetInfoTag).where( - 
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)), + AssetInfoTag.asset_info_id.in_( + sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id) + ), AssetInfoTag.tag_name == "missing", ) ) @@ -264,7 +275,7 @@ def list_tags_with_usage( ) .select_from(AssetInfoTag) .join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id) - .where(_build_visible_owner_clause(owner_id)) + .where(build_visible_owner_clause(owner_id)) .group_by(AssetInfoTag.tag_name) .subquery() ) @@ -323,12 +334,16 @@ def bulk_insert_tags_and_meta( ins_tags = sqlite.insert(AssetInfoTag).on_conflict_do_nothing( index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name] ) - for chunk in _iter_row_chunks(tag_rows, cols_per_row=4): + for chunk in iter_row_chunks(tag_rows, cols_per_row=4): session.execute(ins_tags, chunk) if meta_rows: ins_meta = sqlite.insert(AssetInfoMeta).on_conflict_do_nothing( - index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal] + index_elements=[ + AssetInfoMeta.asset_info_id, + AssetInfoMeta.key, + AssetInfoMeta.ordinal, + ] ) - for chunk in _iter_row_chunks(meta_rows, cols_per_row=7): + for chunk in iter_row_chunks(meta_rows, cols_per_row=7): session.execute(ins_meta, chunk) diff --git a/app/assets/helpers.py b/app/assets/helpers.py index 65b3a0d80..685edaf88 100644 --- a/app/assets/helpers.py +++ b/app/assets/helpers.py @@ -10,7 +10,11 @@ def select_best_live_path(states: Sequence) -> str: 2) Otherwise, pick the first path that exists. 3) Otherwise return empty string. """ - alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)] + alive = [ + s + for s in states + if getattr(s, "file_path", None) and os.path.isfile(s.file_path) + ] if not alive: return "" for s in alive: @@ -19,7 +23,11 @@ def select_best_live_path(states: Sequence) -> str: return alive[0].file_path -ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = ("models", "input", "output") +ALLOWED_ROOTS: tuple[Literal["models", "input", "output"], ...] = ( + "models", + "input", + "output", +) def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]: @@ -43,4 +51,3 @@ def normalize_tags(tags: list[str] | None) -> list[str]: - Removing duplicates. 
""" return [t.strip().lower() for t in (tags or []) if (t or "").strip()] - diff --git a/app/assets/scanner.py b/app/assets/scanner.py index 499d06a2a..a5cceb88a 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -44,7 +44,9 @@ def verify_asset_file_unchanged( ) -> bool: if mtime_db is None: return False - actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)) + actual_mtime_ns = getattr( + stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000) + ) if int(mtime_db) != int(actual_mtime_ns): return False sz = int(size_db or 0) @@ -58,7 +60,9 @@ def list_files_recursively(base_dir: str) -> list[str]: base_abs = os.path.abspath(base_dir) if not os.path.isdir(base_abs): return out - for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False): + for dirpath, _subdirs, filenames in os.walk( + base_abs, topdown=True, followlinks=False + ): for name in filenames: out.append(os.path.abspath(os.path.join(dirpath, name))) return out @@ -141,18 +145,22 @@ def _batch_insert_assets_from_paths( path_list.append(ap) path_to_asset[ap] = aid - asset_rows.append({ - "id": aid, - "hash": None, - "size_bytes": sp["size_bytes"], - "mime_type": None, - "created_at": now, - }) - state_rows.append({ - "asset_id": aid, - "file_path": ap, - "mtime_ns": sp["mtime_ns"], - }) + asset_rows.append( + { + "id": aid, + "hash": None, + "size_bytes": sp["size_bytes"], + "mime_type": None, + "created_at": now, + } + ) + state_rows.append( + { + "asset_id": aid, + "file_path": ap, + "mtime_ns": sp["mtime_ns"], + } + ) asset_to_info[aid] = { "id": iid, "owner_id": owner_id, @@ -179,7 +187,11 @@ def _batch_insert_assets_from_paths( delete_assets_by_ids(session, lost_assets) if not winners_by_path: - return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)} + return { + "inserted_infos": 0, + "won_states": 0, + "lost_states": len(losers_by_path), + } winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path] db_info_rows = [ @@ -209,22 +221,26 @@ def _batch_insert_assets_from_paths( if iid not in inserted_info_ids: continue for t in row["_tags"]: - tag_rows.append({ - "asset_info_id": iid, - "tag_name": t, - "origin": "automatic", - "added_at": now, - }) + tag_rows.append( + { + "asset_info_id": iid, + "tag_name": t, + "origin": "automatic", + "added_at": now, + } + ) if row["_filename"]: - meta_rows.append({ - "asset_info_id": iid, - "key": "filename", - "ordinal": 0, - "val_str": row["_filename"], - "val_num": None, - "val_bool": None, - "val_json": None, - }) + meta_rows.append( + { + "asset_info_id": iid, + "key": "filename", + "ordinal": 0, + "val_str": row["_filename"], + "val_num": None, + "val_bool": None, + "val_json": None, + } + ) bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows) @@ -299,13 +315,15 @@ def sync_cache_states_with_filesystem( except OSError: exists = False - acc["states"].append({ - "sid": row.state_id, - "fp": row.file_path, - "exists": exists, - "fast_ok": fast_ok, - "needs_verify": row.needs_verify, - }) + acc["states"].append( + { + "sid": row.state_id, + "fp": row.file_path, + "exists": exists, + "fast_ok": fast_ok, + "needs_verify": row.needs_verify, + } + ) to_set_verify: list[int] = [] to_clear_verify: list[int] = [] @@ -425,14 +443,18 @@ def _build_asset_specs( if not stat_p.st_size: continue name, tags = get_name_and_tags_from_asset_path(abs_p) - specs.append({ - "abs_path": abs_p, - "size_bytes": stat_p.st_size, - "mtime_ns": 
getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)), - "info_name": name, - "tags": tags, - "fname": compute_relative_filename(abs_p), - }) + specs.append( + { + "abs_path": abs_p, + "size_bytes": stat_p.st_size, + "mtime_ns": getattr( + stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000) + ), + "info_name": name, + "tags": tags, + "fname": compute_relative_filename(abs_p), + } + ) tag_pool.update(tags) return specs, tag_pool, skipped @@ -463,9 +485,7 @@ def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> No for r in roots: existing_paths.update(_sync_root_safely(r)) - all_prefixes = [ - os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r) - ] + all_prefixes = [os.path.abspath(p) for r in roots for p in get_prefixes_for_root(r)] orphans_pruned = _prune_orphans_safely(all_prefixes) paths = _collect_paths_for_roots(roots) diff --git a/app/assets/services/__init__.py b/app/assets/services/__init__.py index 5b1f8f1ab..a35c66b8a 100644 --- a/app/assets/services/__init__.py +++ b/app/assets/services/__init__.py @@ -3,7 +3,6 @@ from app.assets.services.asset_management import ( delete_asset_reference, get_asset_by_hash, get_asset_detail, - get_asset_info_with_tags, list_assets_page, resolve_asset_for_download, set_asset_preview, @@ -49,7 +48,6 @@ __all__ = [ "asset_exists", "get_asset_by_hash", "get_asset_detail", - "get_asset_info_with_tags", "list_assets_page", "resolve_asset_for_download", "update_asset_metadata", diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 97bea19b4..3925bb0b8 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -3,7 +3,6 @@ import mimetypes import os from typing import Sequence -from sqlalchemy.orm import Session from app.assets.database.models import Asset from app.assets.database.queries import ( @@ -14,7 +13,6 @@ from app.assets.database.queries import ( fetch_asset_info_asset_and_tags, get_asset_by_hash as queries_get_asset_by_hash, get_asset_info_by_id, - get_asset_tags, list_asset_infos_page, list_cache_states_by_asset_id, set_asset_info_metadata, @@ -25,7 +23,7 @@ from app.assets.database.queries import ( update_asset_info_updated_at, ) from app.assets.helpers import select_best_live_path -from app.assets.services.path_utils import compute_relative_filename +from app.assets.services.path_utils import compute_filename_for_asset from app.assets.services.schemas import ( AssetData, AssetDetailResult, @@ -80,7 +78,7 @@ def update_asset_metadata( update_asset_info_name(session, asset_info_id=asset_info_id, name=name) touched = True - computed_filename = _compute_filename_for_asset(session, info.asset_id) + computed_filename = compute_filename_for_asset(session, info.asset_id) new_meta: dict | None = None if user_metadata is not None: @@ -138,7 +136,9 @@ def delete_asset_reference( info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id) asset_id = info_row.asset_id if info_row else None - deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id) + deleted = delete_asset_info_by_id( + session, asset_info_id=asset_info_id, owner_id=owner_id + ) if not deleted: session.commit() return False @@ -154,7 +154,9 @@ def delete_asset_reference( # Orphaned asset - delete it and its files states = list_cache_states_by_asset_id(session, asset_id=asset_id) - file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)] + file_paths = [ + s.file_path 
for s in (states or []) if getattr(s, "file_path", None) + ] asset_row = session.get(Asset, asset_id) if asset_row is not None: @@ -206,11 +208,6 @@ def set_asset_preview( return detail -def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None: - primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id)) - return compute_relative_filename(primary_path) if primary_path else None - - def asset_exists(asset_hash: str) -> bool: with create_session() as session: return asset_exists_by_hash(session, asset_hash=asset_hash) @@ -265,7 +262,9 @@ def resolve_asset_for_download( owner_id: str = "", ) -> DownloadResolutionResult: with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) + pair = fetch_asset_info_and_asset( + session, asset_info_id=asset_info_id, owner_id=owner_id + ) if not pair: raise ValueError(f"AssetInfo {asset_info_id} not found") @@ -278,27 +277,14 @@ def resolve_asset_for_download( update_asset_info_access_time(session, asset_info_id=asset_info_id) session.commit() - ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream" + ctype = ( + asset.mime_type + or mimetypes.guess_type(info.name or abs_path)[0] + or "application/octet-stream" + ) download_name = info.name or os.path.basename(abs_path) return DownloadResolutionResult( abs_path=abs_path, content_type=ctype, download_name=download_name, ) - - -def get_asset_info_with_tags( - asset_info_id: str, - owner_id: str = "", -) -> AssetDetailResult | None: - with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id) - if not pair: - return None - info, asset = pair - tags = get_asset_tags(session, asset_info_id=asset_info_id) - return AssetDetailResult( - info=extract_info_data(info), - asset=extract_asset_data(asset), - tags=tags, - ) diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index f6963b44b..461d226a6 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -15,7 +15,6 @@ from app.assets.database.queries import ( get_asset_by_hash, get_asset_tags, get_or_create_asset_info, - list_cache_states_by_asset_id, remove_missing_tag_for_asset_id, set_asset_info_metadata, set_asset_info_tags, @@ -23,9 +22,9 @@ from app.assets.database.queries import ( upsert_asset, upsert_cache_state, ) -from app.assets.helpers import normalize_tags, select_best_live_path +from app.assets.helpers import normalize_tags from app.assets.services.path_utils import ( - compute_relative_filename, + compute_filename_for_asset, resolve_destination_from_tags, validate_path_within_base, ) @@ -92,7 +91,9 @@ def ingest_file_from_path( if info_created: asset_info_id = info.id else: - update_asset_info_timestamps(session, asset_info=info, preview_id=preview_id) + update_asset_info_timestamps( + session, asset_info=info, preview_id=preview_id + ) asset_info_id = info.id norm = normalize_tags(list(tags)) @@ -165,7 +166,7 @@ def register_existing_asset( return result new_meta = dict(user_metadata or {}) - computed_filename = _compute_filename_for_asset(session, asset.id) + computed_filename = compute_filename_for_asset(session, asset.id) if computed_filename: new_meta["filename"] = computed_filename @@ -199,18 +200,14 @@ def register_existing_asset( def _validate_tags_exist(session: Session, tags: list[str]) -> None: existing_tag_names = set( - name for (name,) in 
session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() + name + for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all() ) missing = [t for t in tags if t not in existing_tag_names] if missing: raise ValueError(f"Unknown tags: {missing}") -def _compute_filename_for_asset(session: Session, asset_id: str) -> str | None: - primary_path = select_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset_id)) - return compute_relative_filename(primary_path) if primary_path else None - - def _update_metadata_with_filename( session: Session, asset_info_id: str, @@ -218,7 +215,7 @@ def _update_metadata_with_filename( info: AssetInfo, user_metadata: UserMetadata, ) -> None: - computed_filename = _compute_filename_for_asset(session, asset_id) + computed_filename = compute_filename_for_asset(session, asset_id) current_meta = info.user_metadata or {} new_meta = dict(current_meta) @@ -346,7 +343,9 @@ def upload_from_temp_path( raise RuntimeError("failed to create asset metadata") with create_session() as session: - pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id) + pair = fetch_asset_info_and_asset( + session, asset_info_id=info_id, owner_id=owner_id + ) if not pair: raise RuntimeError("inconsistent DB state after ingest") info, asset = pair @@ -376,7 +375,9 @@ def create_from_hash( result = register_existing_asset( asset_hash=canonical, - name=_sanitize_filename(name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical), + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), user_metadata=user_metadata or {}, tags=tags or [], tag_origin="manual", diff --git a/app/assets/services/path_utils.py b/app/assets/services/path_utils.py index cd2f87d6c..ade4bb0bd 100644 --- a/app/assets/services/path_utils.py +++ b/app/assets/services/path_utils.py @@ -15,7 +15,10 @@ def get_comfy_models_folders() -> list[tuple[str, list[str]]]: targets: list[tuple[str, list[str]]] = [] models_root = os.path.abspath(folder_paths.models_dir) for name, values in folder_paths.folder_names_and_paths.items(): - paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI + paths, _exts = ( + values[0], + values[1], + ) # NOTE: this prevents nodepacks that hackily edit folder_... 
from breaking ComfyUI if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths): targets.append((name, paths)) return targets @@ -37,7 +40,9 @@ def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]: raw_subdirs = tags[2:] else: base_dir = os.path.abspath( - folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory() + folder_paths.get_input_directory() + if root == "input" + else folder_paths.get_output_directory() ) raw_subdirs = tags[1:] for i in raw_subdirs: @@ -84,7 +89,9 @@ def compute_relative_filename(file_path: str) -> str | None: return "/".join(parts) # input/output: keep all parts -def get_asset_category_and_relative_path(file_path: str) -> tuple[Literal["input", "output", "models"], str]: +def get_asset_category_and_relative_path( + file_path: str, +) -> tuple[Literal["input", "output", "models"], str]: """Given an absolute or relative file path, determine which root category the path belongs to: - 'input' if the file resides under `folder_paths.get_input_directory()` - 'output' if the file resides under `folder_paths.get_output_directory()` @@ -107,7 +114,9 @@ def get_asset_category_and_relative_path(file_path: str) -> tuple[Literal["input return False def _compute_relative(child: str, parent: str) -> str: - return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep) + return os.path.relpath( + os.path.join(os.sep, os.path.relpath(child, parent)), os.sep + ) # 1) input input_base = os.path.abspath(folder_paths.get_input_directory()) @@ -135,7 +144,20 @@ def get_asset_category_and_relative_path(file_path: str) -> tuple[Literal["input combined = os.path.join(bucket, rel_inside) return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep) - raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}") + raise ValueError( + f"Path is not within input, output, or configured model bases: {file_path}" + ) + + +def compute_filename_for_asset(session, asset_id: str) -> str | None: + """Compute the relative filename for an asset from its best live cache state path.""" + from app.assets.database.queries import list_cache_states_by_asset_id + from app.assets.helpers import select_best_live_path + + primary_path = select_best_live_path( + list_cache_states_by_asset_id(session, asset_id=asset_id) + ) + return compute_relative_filename(primary_path) if primary_path else None def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: @@ -156,5 +178,7 @@ def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]: """ root_category, some_path = get_asset_category_and_relative_path(file_path) p = Path(some_path) - parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)] + parent_parts = [ + part for part in p.parent.parts if part not in (".", "..", p.anchor) + ] return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
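
A few sketches of how the consolidated pieces behave follow; none of the code below is part of the diff above. First, the chunking helpers moved into app/assets/database/queries/common.py: SQLite builds before 3.32.0 cap host parameters at 999 per statement, and multi-row inserts spend one bind parameter per column per row, so MAX_BIND_PARAMS = 800 leaves headroom. A minimal standalone sketch, using raw sqlite3 and a made-up widgets table rather than the real models:

import sqlite3

# Mirrors the helpers added in app/assets/database/queries/common.py.
MAX_BIND_PARAMS = 800

def calculate_rows_per_statement(cols: int) -> int:
    # Always at least one row per statement, even for very wide tables.
    return max(1, MAX_BIND_PARAMS // max(1, cols))

def iter_chunks(seq, n: int):
    for i in range(0, len(seq), n):
        yield seq[i : i + n]

conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE widgets (id INTEGER, name TEXT, size INTEGER)")

rows = [(i, f"w{i}", i * 10) for i in range(10_000)]
for chunk in iter_chunks(rows, calculate_rows_per_statement(3)):
    # 800 // 3 = 266 rows per statement, i.e. at most 798 parameters bound.
    placeholders = ", ".join("(?, ?, ?)" for _ in chunk)
    conn.execute(f"INSERT INTO widgets VALUES {placeholders}",
                 [v for row in chunk for v in row])

print(conn.execute("SELECT COUNT(*) FROM widgets").fetchone()[0])  # 10000

iter_row_chunks applies the same budget to list-of-dict rows, which is the shape the insert().on_conflict_do_nothing() executions in tags.py and cache_state.py consume.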
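
Second, the visibility rule that build_visible_owner_clause centralizes: rows whose owner_id is "" are global and visible to everyone, and a non-empty owner additionally sees their own rows. A self-contained sketch against a hypothetical Item model standing in for AssetInfo:

import sqlalchemy as sa
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column

class Base(DeclarativeBase):
    pass

class Item(Base):  # hypothetical stand-in for AssetInfo
    __tablename__ = "items"
    id: Mapped[int] = mapped_column(primary_key=True)
    owner_id: Mapped[str] = mapped_column(default="")

def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
    # Same logic as the helper in queries/common.py, applied to the stand-in.
    owner_id = (owner_id or "").strip()
    if owner_id == "":
        return Item.owner_id == ""
    return Item.owner_id.in_(["", owner_id])

engine = sa.create_engine("sqlite://")
Base.metadata.create_all(engine)
with Session(engine) as s:
    s.add_all([Item(owner_id=""), Item(owner_id="alice"), Item(owner_id="bob")])
    s.commit()
    visible = s.scalars(
        sa.select(Item).where(build_visible_owner_clause("alice"))
    ).all()
    print(sorted(i.owner_id for i in visible))  # ['', 'alice'] -- bob's row is hidden

Moving this predicate out of asset_info.py and tags.py means the two modules can no longer drift apart on what "visible" means.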
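
Third, the typed metadata projection in asset_info.py (this patch only renames check_is_scalar to _check_is_scalar, but the surrounding hunks show the scheme): each user-metadata value is flattened into AssetInfoMeta rows with exactly one typed value column populated. The hunks elide the string/number branches of _scalar_to_row and the tail of _check_is_scalar, so the lines marked "assumed" below are a reconstruction, not code from the patch:

def _check_is_scalar(v):
    if v is None:
        return True
    if isinstance(v, bool):
        return True
    return isinstance(v, (str, int, float))  # assumed; elided from the hunk

def _scalar_to_row(key: str, ordinal: int, value) -> dict:
    """Convert a scalar value to a typed projection row."""
    if value is None:
        return {"key": key, "ordinal": ordinal, "val_str": None,
                "val_num": None, "val_bool": None, "val_json": None}
    if isinstance(value, bool):
        return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
    if isinstance(value, (int, float)):  # assumed; elided from the hunk
        return {"key": key, "ordinal": ordinal, "val_num": float(value)}
    return {"key": key, "ordinal": ordinal, "val_str": str(value)}  # assumed

def convert_metadata_to_rows(key: str, value) -> list[dict]:
    if value is None:
        return [_scalar_to_row(key, 0, None)]
    if _check_is_scalar(value):
        return [_scalar_to_row(key, 0, value)]
    if isinstance(value, list):
        if all(_check_is_scalar(x) for x in value):
            return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
        return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
    return [{"key": key, "ordinal": 0, "val_json": value}]

print(convert_metadata_to_rows("steps", 28))                  # one val_num row
print(convert_metadata_to_rows("words", ["style", "photo"]))  # ordinals 0 and 1
print(convert_metadata_to_rows("extra", {"nested": True}))    # val_json fallback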
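
Last, a reading note on the INVALID_HASH paths: head_asset_by_hash in routes.py and normalize_and_validate_hash in upload.py still carry their own copies of the hash-format check, and in the hunks above their message renders as "hash must be like 'blake3:'" -- an angle-bracketed placeholder (something like '<hex>') appears to have been lost to markup stripping, so the placeholder in this sketch is an assumption. ValueError stands in for the patch's UploadError(status, code, message), and the .strip().lower() normalization that the call sites perform is folded in here:

def normalize_and_validate_hash(s: str) -> str:
    # Canonical asset hash: "blake3:" followed by a lowercase hex digest.
    s = (s or "").strip().lower()  # done by the call sites in the patch
    if ":" not in s:
        raise ValueError("hash must be like 'blake3:<hex>'")  # '<hex>' assumed
    algo, digest = s.split(":", 1)
    if algo != "blake3" or not digest or any(
        c not in "0123456789abcdef" for c in digest
    ):
        raise ValueError("hash must be like 'blake3:<hex>'")
    return f"{algo}:{digest}"

print(normalize_and_validate_hash("BLAKE3:ABCDEF0123"))  # blake3:abcdef0123

The patch's any(c for c in digest if c not in "0123456789abcdef") is equivalent to the any(c not in ... for c in digest) form used here: every character the filter yields is a non-empty string and therefore truthy.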