diff --git a/.github/scripts/check-ai-co-authors.sh b/.github/scripts/check-ai-co-authors.sh new file mode 100755 index 000000000..842b1f2d8 --- /dev/null +++ b/.github/scripts/check-ai-co-authors.sh @@ -0,0 +1,103 @@ +#!/usr/bin/env bash +# Checks pull request commits for AI agent Co-authored-by trailers. +# Exits non-zero when any are found and prints fix instructions. +set -euo pipefail + +base_sha="${1:?usage: check-ai-co-authors.sh }" +head_sha="${2:?usage: check-ai-co-authors.sh }" + +# Known AI coding-agent trailer patterns (case-insensitive). +# Each entry is an extended-regex fragment matched against Co-authored-by lines. +AGENT_PATTERNS=( + # Anthropic — Claude Code / Amp + 'noreply@anthropic\.com' + # Cursor + 'cursoragent@cursor\.com' + # GitHub Copilot + 'copilot-swe-agent\[bot\]' + 'copilot@github\.com' + # OpenAI Codex + 'noreply@openai\.com' + 'codex@openai\.com' + # Aider + 'aider@aider\.chat' + # Google — Gemini / Jules + 'gemini@google\.com' + 'jules@google\.com' + # Windsurf / Codeium + '@codeium\.com' + # Devin + 'devin-ai-integration\[bot\]' + 'devin@cognition\.ai' + 'devin@cognition-labs\.com' + # Amazon Q Developer + 'amazon-q-developer' + '@amazon\.com.*[Qq].[Dd]eveloper' + # Cline + 'cline-bot' + 'cline@cline\.ai' + # Continue + 'continue-agent' + 'continue@continue\.dev' + # Sourcegraph + 'noreply@sourcegraph\.com' + # Generic catch-alls for common agent name patterns + 'Co-authored-by:.*\b[Cc]laude\b' + 'Co-authored-by:.*\b[Cc]opilot\b' + 'Co-authored-by:.*\b[Cc]ursor\b' + 'Co-authored-by:.*\b[Cc]odex\b' + 'Co-authored-by:.*\b[Gg]emini\b' + 'Co-authored-by:.*\b[Aa]ider\b' + 'Co-authored-by:.*\b[Dd]evin\b' + 'Co-authored-by:.*\b[Ww]indsurf\b' + 'Co-authored-by:.*\b[Cc]line\b' + 'Co-authored-by:.*\b[Aa]mazon Q\b' + 'Co-authored-by:.*\b[Jj]ules\b' + 'Co-authored-by:.*\bOpenCode\b' +) + +# Build a single alternation regex from all patterns. +regex="" +for pattern in "${AGENT_PATTERNS[@]}"; do + if [[ -n "$regex" ]]; then + regex="${regex}|${pattern}" + else + regex="$pattern" + fi +done + +# Collect Co-authored-by lines from every commit in the PR range. +violations="" +while IFS= read -r sha; do + message="$(git log -1 --format='%B' "$sha")" + matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)" + if [[ -z "$matched_lines" ]]; then + continue + fi + + while IFS= read -r line; do + if echo "$line" | grep -iqE "$regex"; then + short="$(git log -1 --format='%h' "$sha")" + violations="${violations} ${short}: ${line}"$'\n' + fi + done <<< "$matched_lines" +done < <(git rev-list "${base_sha}..${head_sha}") + +if [[ -n "$violations" ]]; then + echo "::error::AI agent Co-authored-by trailers detected in PR commits." + echo "" + echo "The following commits contain Co-authored-by trailers from AI coding agents:" + echo "" + echo "$violations" + echo "These trailers should be removed before merging." + echo "" + echo "To fix, rewrite the commit messages with:" + echo " git rebase -i ${base_sha}" + echo "" + echo "and remove the Co-authored-by lines, then force-push your branch." + echo "" + echo "If you believe this is a false positive, please open an issue." + exit 1 +fi + +echo "No AI agent Co-authored-by trailers found." diff --git a/.github/workflows/check-ai-co-authors.yml b/.github/workflows/check-ai-co-authors.yml new file mode 100644 index 000000000..2ad9ac972 --- /dev/null +++ b/.github/workflows/check-ai-co-authors.yml @@ -0,0 +1,19 @@ +name: Check AI Co-Authors + +on: + pull_request: + branches: ['*'] + +jobs: + check-ai-co-authors: + name: Check for AI agent co-author trailers + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Check commits for AI co-author trailers + run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}" diff --git a/alembic_db/env.py b/alembic_db/env.py index 4d7770679..4ce37c012 100644 --- a/alembic_db/env.py +++ b/alembic_db/env.py @@ -8,7 +8,7 @@ from alembic import context config = context.config -from app.database.models import Base +from app.database.models import Base, NAMING_CONVENTION target_metadata = Base.metadata # other values from the config, defined by the needs of env.py, @@ -51,7 +51,10 @@ def run_migrations_online() -> None: with connectable.connect() as connection: context.configure( - connection=connection, target_metadata=target_metadata + connection=connection, + target_metadata=target_metadata, + render_as_batch=True, + naming_convention=NAMING_CONVENTION, ) with context.begin_transaction(): diff --git a/alembic_db/versions/0003_add_metadata_job_id.py b/alembic_db/versions/0003_add_metadata_job_id.py new file mode 100644 index 000000000..2a14ee924 --- /dev/null +++ b/alembic_db/versions/0003_add_metadata_job_id.py @@ -0,0 +1,98 @@ +""" +Add system_metadata and job_id columns to asset_references. +Change preview_id FK from assets.id to asset_references.id. + +Revision ID: 0003_add_metadata_job_id +Revises: 0002_merge_to_asset_references +Create Date: 2026-03-09 +""" + +from alembic import op +import sqlalchemy as sa + +from app.database.models import NAMING_CONVENTION + +revision = "0003_add_metadata_job_id" +down_revision = "0002_merge_to_asset_references" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with op.batch_alter_table("asset_references") as batch_op: + batch_op.add_column( + sa.Column("system_metadata", sa.JSON(), nullable=True) + ) + batch_op.add_column( + sa.Column("job_id", sa.String(length=36), nullable=True) + ) + + # Change preview_id FK from assets.id to asset_references.id (self-ref). + # Existing values are asset-content IDs that won't match reference IDs, + # so null them out first. + op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL") + with op.batch_alter_table( + "asset_references", naming_convention=NAMING_CONVENTION + ) as batch_op: + batch_op.drop_constraint( + "fk_asset_references_preview_id_assets", type_="foreignkey" + ) + batch_op.create_foreign_key( + "fk_asset_references_preview_id_asset_references", + "asset_references", + ["preview_id"], + ["id"], + ondelete="SET NULL", + ) + batch_op.create_index( + "ix_asset_references_preview_id", ["preview_id"] + ) + + # Purge any all-null meta rows before adding the constraint + op.execute( + "DELETE FROM asset_reference_meta" + " WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL" + ) + with op.batch_alter_table("asset_reference_meta") as batch_op: + batch_op.create_check_constraint( + "ck_asset_reference_meta_has_value", + "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL", + ) + + +def downgrade() -> None: + # SQLite doesn't reflect CHECK constraints, so we must declare it + # explicitly via table_args for the batch recreate to find it. + # Use the fully-rendered constraint name to avoid the naming convention + # doubling the prefix. + with op.batch_alter_table( + "asset_reference_meta", + table_args=[ + sa.CheckConstraint( + "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL", + name="ck_asset_reference_meta_has_value", + ), + ], + ) as batch_op: + batch_op.drop_constraint( + "ck_asset_reference_meta_has_value", type_="check" + ) + + with op.batch_alter_table( + "asset_references", naming_convention=NAMING_CONVENTION + ) as batch_op: + batch_op.drop_index("ix_asset_references_preview_id") + batch_op.drop_constraint( + "fk_asset_references_preview_id_asset_references", type_="foreignkey" + ) + batch_op.create_foreign_key( + "fk_asset_references_preview_id_assets", + "assets", + ["preview_id"], + ["id"], + ondelete="SET NULL", + ) + + with op.batch_alter_table("asset_references") as batch_op: + batch_op.drop_column("job_id") + batch_op.drop_column("system_metadata") diff --git a/app/assets/api/routes.py b/app/assets/api/routes.py index 40dee9f46..68126b6a5 100644 --- a/app/assets/api/routes.py +++ b/app/assets/api/routes.py @@ -13,6 +13,7 @@ from pydantic import ValidationError import folder_paths from app import user_manager from app.assets.api import schemas_in, schemas_out +from app.assets.services import schemas from app.assets.api.schemas_in import ( AssetValidationError, UploadError, @@ -38,6 +39,7 @@ from app.assets.services import ( update_asset_metadata, upload_from_temp_path, ) +from app.assets.services.tagging import list_tag_histogram ROUTES = web.RouteTableDef() USER_MANAGER: user_manager.UserManager | None = None @@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str: return "created_at" +def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None: + """Build a /api/view preview URL from asset tags and user_metadata filename.""" + if not user_metadata: + return None + filename = user_metadata.get("filename") + if not filename: + return None + + if "input" in tags: + view_type = "input" + elif "output" in tags: + view_type = "output" + else: + return None + + subfolder = "" + if "/" in filename: + subfolder, filename = filename.rsplit("/", 1) + + encoded_filename = urllib.parse.quote(filename, safe="") + url = f"/api/view?type={view_type}&filename={encoded_filename}" + if subfolder: + url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}" + return url + + +def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset: + """Build an Asset response from a service result.""" + if result.ref.preview_id: + preview_detail = get_asset_detail(result.ref.preview_id) + if preview_detail: + preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata) + else: + preview_url = None + else: + preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata) + return schemas_out.Asset( + id=result.ref.id, + name=result.ref.name, + asset_hash=result.asset.hash if result.asset else None, + size=int(result.asset.size_bytes) if result.asset else None, + mime_type=result.asset.mime_type if result.asset else None, + tags=result.tags, + preview_url=preview_url, + preview_id=result.ref.preview_id, + user_metadata=result.ref.user_metadata or {}, + metadata=result.ref.system_metadata, + job_id=result.ref.job_id, + prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat + created_at=result.ref.created_at, + updated_at=result.ref.updated_at, + last_access_time=result.ref.last_access_time, + ) + + @ROUTES.head("/api/assets/hash/{hash}") @_require_assets_feature_enabled async def head_asset_by_hash(request: web.Request) -> web.Response: @@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response: order=order, ) - summaries = [ - schemas_out.AssetSummary( - id=item.ref.id, - name=item.ref.name, - asset_hash=item.asset.hash if item.asset else None, - size=int(item.asset.size_bytes) if item.asset else None, - mime_type=item.asset.mime_type if item.asset else None, - tags=item.tags, - created_at=item.ref.created_at, - updated_at=item.ref.updated_at, - last_access_time=item.ref.last_access_time, - ) - for item in result.items - ] + summaries = [_build_asset_response(item) for item in result.items] payload = schemas_out.AssetsList( assets=summaries, @@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response: {"id": reference_id}, ) - payload = schemas_out.AssetDetail( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash if result.asset else None, - size=int(result.asset.size_bytes) if result.asset else None, - mime_type=result.asset.mime_type if result.asset else None, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - preview_id=result.ref.preview_id, - created_at=result.ref.created_at, - last_access_time=result.ref.last_access_time, - ) + payload = _build_asset_response(result) except ValueError as e: return _build_error_response( 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} @@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response: USER_MANAGER.get_request_user_id(request), ) return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(payload.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") @@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response: 400, "INVALID_JSON", "Request body must be valid JSON." ) + # Derive name from hash if not provided + name = body.name + if name is None: + name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash + result = create_from_hash( hash_str=body.hash, - name=body.name, + name=name, tags=body.tags, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), + mime_type=body.mime_type, + preview_id=body.preview_id, ) if result is None: return _build_error_response( 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" ) + asset = _build_asset_response(result) payload_out = schemas_out.AssetCreated( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash, - size=int(result.asset.size_bytes), - mime_type=result.asset.mime_type, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - preview_id=result.ref.preview_id, - created_at=result.ref.created_at, - last_access_time=result.ref.last_access_time, + **asset.model_dump(), created_new=result.created_new, ) - return web.json_response(payload_out.model_dump(mode="json"), status=201) + return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201) @ROUTES.post("/api/assets") @@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response: "name": parsed.provided_name, "user_metadata": parsed.user_metadata_raw, "hash": parsed.provided_hash, + "mime_type": parsed.provided_mime_type, + "preview_id": parsed.provided_preview_id, } ) except ValidationError as ve: @@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response: tags=spec.tags, user_metadata=spec.user_metadata or {}, owner_id=owner_id, + mime_type=spec.mime_type, + preview_id=spec.preview_id, ) if result is None: delete_temp_file_if_exists(parsed.tmp_path) @@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response: client_filename=parsed.file_client_name, owner_id=owner_id, expected_hash=spec.hash, + mime_type=spec.mime_type, + preview_id=spec.preview_id, ) except AssetValidationError as e: delete_temp_file_if_exists(parsed.tmp_path) @@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response: logging.exception("upload_asset failed for owner_id=%s", owner_id) return _build_error_response(500, "INTERNAL", "Unexpected server error.") - payload = schemas_out.AssetCreated( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash, - size=int(result.asset.size_bytes), - mime_type=result.asset.mime_type, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - preview_id=result.ref.preview_id, - created_at=result.ref.created_at, - last_access_time=result.ref.last_access_time, + asset = _build_asset_response(result) + payload_out = schemas_out.AssetCreated( + **asset.model_dump(), created_new=result.created_new, ) status = 201 if result.created_new else 200 - return web.json_response(payload.model_dump(mode="json"), status=status) + return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status) @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response: name=body.name, user_metadata=body.user_metadata, owner_id=USER_MANAGER.get_request_user_id(request), + preview_id=body.preview_id, ) - payload = schemas_out.AssetUpdated( - id=result.ref.id, - name=result.ref.name, - asset_hash=result.asset.hash if result.asset else None, - tags=result.tags, - user_metadata=result.ref.user_metadata or {}, - updated_at=result.ref.updated_at, - ) + payload = _build_asset_response(result) except PermissionError as pe: return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) except ValueError as ve: @@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response: USER_MANAGER.get_request_user_id(request), ) return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(payload.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") @@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response: payload = schemas_out.TagsList( tags=tags, total=total, has_more=(query.offset + len(tags)) < total ) - return web.json_response(payload.model_dump(mode="json")) + return web.json_response(payload.model_dump(mode="json", exclude_none=True)) @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") @@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response: ) return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(payload.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") @@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response: ) return _build_error_response(500, "INTERNAL", "Unexpected server error.") - return web.json_response(payload.model_dump(mode="json"), status=200) + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) + + +@ROUTES.get("/api/assets/tags/refine") +@_require_assets_feature_enabled +async def get_tags_refine(request: web.Request) -> web.Response: + """GET request to get tag histogram for filtered assets.""" + query_dict = get_query_dict(request) + try: + q = schemas_in.TagsRefineQuery.model_validate(query_dict) + except ValidationError as ve: + return _build_validation_error_response("INVALID_QUERY", ve) + + tag_counts = list_tag_histogram( + owner_id=USER_MANAGER.get_request_user_id(request), + include_tags=q.include_tags, + exclude_tags=q.exclude_tags, + name_contains=q.name_contains, + metadata_filter=q.metadata_filter, + limit=q.limit, + ) + payload = schemas_out.TagHistogram(tag_counts=tag_counts) + return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200) @ROUTES.post("/api/assets/seed") diff --git a/app/assets/api/schemas_in.py b/app/assets/api/schemas_in.py index d255c938e..186a6ae1e 100644 --- a/app/assets/api/schemas_in.py +++ b/app/assets/api/schemas_in.py @@ -45,6 +45,8 @@ class ParsedUpload: user_metadata_raw: str | None provided_hash: str | None provided_hash_exists: bool | None + provided_mime_type: str | None = None + provided_preview_id: str | None = None class ListAssetsQuery(BaseModel): @@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel): class UpdateAssetBody(BaseModel): name: str | None = None user_metadata: dict[str, Any] | None = None + preview_id: str | None = None # references an asset_reference id, not an asset id @model_validator(mode="after") def _validate_at_least_one_field(self): - if self.name is None and self.user_metadata is None: - raise ValueError("Provide at least one of: name, user_metadata.") + if all( + v is None + for v in (self.name, self.user_metadata, self.preview_id) + ): + raise ValueError( + "Provide at least one of: name, user_metadata, preview_id." + ) return self @@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) hash: str - name: str + name: str | None = None tags: list[str] = Field(default_factory=list) user_metadata: dict[str, Any] = Field(default_factory=dict) + mime_type: str | None = None + preview_id: str | None = None # references an asset_reference id, not an asset id @field_validator("hash") @classmethod @@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel): return [] +class TagsRefineQuery(BaseModel): + include_tags: list[str] = Field(default_factory=list) + exclude_tags: list[str] = Field(default_factory=list) + name_contains: str | None = None + metadata_filter: dict[str, Any] | None = None + limit: conint(ge=1, le=1000) = 100 + + @field_validator("include_tags", "exclude_tags", mode="before") + @classmethod + def _split_csv_tags(cls, v): + if v is None: + return [] + if isinstance(v, str): + return [t.strip() for t in v.split(",") if t.strip()] + if isinstance(v, list): + out: list[str] = [] + for item in v: + if isinstance(item, str): + out.extend([t.strip() for t in item.split(",") if t.strip()]) + return out + return v + + @field_validator("metadata_filter", mode="before") + @classmethod + def _parse_metadata_json(cls, v): + if v is None or isinstance(v, dict): + return v + if isinstance(v, str) and v.strip(): + try: + parsed = json.loads(v) + except Exception as e: + raise ValueError(f"metadata_filter must be JSON: {e}") from e + if not isinstance(parsed, dict): + raise ValueError("metadata_filter must be a JSON object") + return parsed + return None + + class TagsListQuery(BaseModel): model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) @@ -186,21 +234,25 @@ class TagsRemove(TagsAdd): class UploadAssetSpec(BaseModel): """Upload Asset operation. - - tags: ordered; first is root ('models'|'input'|'output'); + - tags: optional list; if provided, first is root ('models'|'input'|'output'); if root == 'models', second must be a valid category - name: display name - user_metadata: arbitrary JSON object (optional) - hash: optional canonical 'blake3:' for validation / fast-path + - mime_type: optional MIME type override + - preview_id: optional asset_reference ID for preview Files are stored using the content hash as filename stem. """ model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) - tags: list[str] = Field(..., min_length=1) + tags: list[str] = Field(default_factory=list) name: str | None = Field(default=None, max_length=512, description="Display Name") user_metadata: dict[str, Any] = Field(default_factory=dict) hash: str | None = Field(default=None) + mime_type: str | None = Field(default=None) + preview_id: str | None = Field(default=None) # references an asset_reference id @field_validator("hash", mode="before") @classmethod @@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel): @model_validator(mode="after") def _validate_order(self): if not self.tags: - raise ValueError("tags must be provided and non-empty") + raise ValueError("at least one tag is required for uploads") root = self.tags[0] if root not in {"models", "input", "output"}: raise ValueError("first tag must be one of: models, input, output") diff --git a/app/assets/api/schemas_out.py b/app/assets/api/schemas_out.py index f36447856..d99b1098d 100644 --- a/app/assets/api/schemas_out.py +++ b/app/assets/api/schemas_out.py @@ -4,7 +4,10 @@ from typing import Any from pydantic import BaseModel, ConfigDict, Field, field_serializer -class AssetSummary(BaseModel): +class Asset(BaseModel): + """API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob; + ``id`` here is the AssetReference id, not the content-addressed Asset id.""" + id: str name: str asset_hash: str | None = None @@ -12,8 +15,14 @@ class AssetSummary(BaseModel): mime_type: str | None = None tags: list[str] = Field(default_factory=list) preview_url: str | None = None - created_at: datetime | None = None - updated_at: datetime | None = None + preview_id: str | None = None # references an asset_reference id, not an asset id + user_metadata: dict[str, Any] = Field(default_factory=dict) + is_immutable: bool = False + metadata: dict[str, Any] | None = None + job_id: str | None = None + prompt_id: str | None = None # deprecated: use job_id + created_at: datetime + updated_at: datetime last_access_time: datetime | None = None model_config = ConfigDict(from_attributes=True) @@ -23,50 +32,16 @@ class AssetSummary(BaseModel): return v.isoformat() if v else None +class AssetCreated(Asset): + created_new: bool + + class AssetsList(BaseModel): - assets: list[AssetSummary] + assets: list[Asset] total: int has_more: bool -class AssetUpdated(BaseModel): - id: str - name: str - asset_hash: str | None = None - tags: list[str] = Field(default_factory=list) - user_metadata: dict[str, Any] = Field(default_factory=dict) - updated_at: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("updated_at") - def _serialize_updated_at(self, v: datetime | None, _info): - return v.isoformat() if v else None - - -class AssetDetail(BaseModel): - id: str - name: str - asset_hash: str | None = None - size: int | None = None - mime_type: str | None = None - tags: list[str] = Field(default_factory=list) - user_metadata: dict[str, Any] = Field(default_factory=dict) - preview_id: str | None = None - created_at: datetime | None = None - last_access_time: datetime | None = None - - model_config = ConfigDict(from_attributes=True) - - @field_serializer("created_at", "last_access_time") - def _serialize_datetime(self, v: datetime | None, _info): - return v.isoformat() if v else None - - -class AssetCreated(AssetDetail): - created_new: bool - - class TagUsage(BaseModel): name: str count: int @@ -91,3 +66,7 @@ class TagsRemove(BaseModel): removed: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list) + + +class TagHistogram(BaseModel): + tag_counts: dict[str, int] diff --git a/app/assets/api/upload.py b/app/assets/api/upload.py index 721c12f4d..13d3d372c 100644 --- a/app/assets/api/upload.py +++ b/app/assets/api/upload.py @@ -52,6 +52,8 @@ async def parse_multipart_upload( user_metadata_raw: str | None = None provided_hash: str | None = None provided_hash_exists: bool | None = None + provided_mime_type: str | None = None + provided_preview_id: str | None = None file_written = 0 tmp_path: str | None = None @@ -128,6 +130,16 @@ async def parse_multipart_upload( provided_name = (await field.text()) or None elif fname == "user_metadata": user_metadata_raw = (await field.text()) or None + elif fname == "id": + raise UploadError( + 400, + "UNSUPPORTED_FIELD", + "Client-provided 'id' is not supported. Asset IDs are assigned by the server.", + ) + elif fname == "mime_type": + provided_mime_type = ((await field.text()) or "").strip() or None + elif fname == "preview_id": + provided_preview_id = ((await field.text()) or "").strip() or None if not file_present and not (provided_hash and provided_hash_exists): raise UploadError( @@ -152,6 +164,8 @@ async def parse_multipart_upload( user_metadata_raw=user_metadata_raw, provided_hash=provided_hash, provided_hash_exists=provided_hash_exists, + provided_mime_type=provided_mime_type, + provided_preview_id=provided_preview_id, ) diff --git a/app/assets/database/models.py b/app/assets/database/models.py index 03c1c1707..a3af8a192 100644 --- a/app/assets/database/models.py +++ b/app/assets/database/models.py @@ -45,13 +45,7 @@ class Asset(Base): passive_deletes=True, ) - preview_of: Mapped[list[AssetReference]] = relationship( - "AssetReference", - back_populates="preview_asset", - primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id), - foreign_keys=lambda: [AssetReference.preview_id], - viewonly=True, - ) + # preview_id on AssetReference is a self-referential FK to asset_references.id __table_args__ = ( Index("uq_assets_hash", "hash", unique=True), @@ -91,11 +85,15 @@ class AssetReference(Base): owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") name: Mapped[str] = mapped_column(String(512), nullable=False) preview_id: Mapped[str | None] = mapped_column( - String(36), ForeignKey("assets.id", ondelete="SET NULL") + String(36), ForeignKey("asset_references.id", ondelete="SET NULL") ) user_metadata: Mapped[dict[str, Any] | None] = mapped_column( JSON(none_as_null=True) ) + system_metadata: Mapped[dict[str, Any] | None] = mapped_column( + JSON(none_as_null=True), nullable=True, default=None + ) + job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None) created_at: Mapped[datetime] = mapped_column( DateTime(timezone=False), nullable=False, default=get_utc_now ) @@ -115,10 +113,10 @@ class AssetReference(Base): foreign_keys=[asset_id], lazy="selectin", ) - preview_asset: Mapped[Asset | None] = relationship( - "Asset", - back_populates="preview_of", + preview_ref: Mapped[AssetReference | None] = relationship( + "AssetReference", foreign_keys=[preview_id], + remote_side=lambda: [AssetReference.id], ) metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( @@ -152,6 +150,7 @@ class AssetReference(Base): Index("ix_asset_references_created_at", "created_at"), Index("ix_asset_references_last_access_time", "last_access_time"), Index("ix_asset_references_deleted_at", "deleted_at"), + Index("ix_asset_references_preview_id", "preview_id"), Index("ix_asset_references_owner_name", "owner_id", "name"), CheckConstraint( "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" @@ -192,6 +191,10 @@ class AssetReferenceMeta(Base): Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), + CheckConstraint( + "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL", + name="has_value", + ), ) diff --git a/app/assets/database/queries/__init__.py b/app/assets/database/queries/__init__.py index 7888d0645..1632937b2 100644 --- a/app/assets/database/queries/__init__.py +++ b/app/assets/database/queries/__init__.py @@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import ( get_unenriched_references, get_unreferenced_unhashed_asset_ids, insert_reference, + list_all_file_paths_by_asset_id, list_references_by_asset_id, list_references_page, mark_references_missing_outside_prefixes, + rebuild_metadata_projection, + reference_exists, reference_exists_for_asset_id, restore_references_by_paths, set_reference_metadata, set_reference_preview, + set_reference_system_metadata, soft_delete_reference_by_id, update_reference_access_time, update_reference_name, + update_is_missing_by_asset_id, update_reference_timestamps, update_reference_updated_at, upsert_reference, @@ -54,6 +59,7 @@ from app.assets.database.queries.tags import ( bulk_insert_tags_and_meta, ensure_tags_exist, get_reference_tags, + list_tag_counts_for_filtered_assets, list_tags_with_usage, remove_missing_tag_for_asset_id, remove_tags_from_reference, @@ -97,20 +103,26 @@ __all__ = [ "get_unenriched_references", "get_unreferenced_unhashed_asset_ids", "insert_reference", + "list_all_file_paths_by_asset_id", "list_references_by_asset_id", "list_references_page", + "list_tag_counts_for_filtered_assets", "list_tags_with_usage", "mark_references_missing_outside_prefixes", "reassign_asset_references", + "rebuild_metadata_projection", + "reference_exists", "reference_exists_for_asset_id", "remove_missing_tag_for_asset_id", "remove_tags_from_reference", "restore_references_by_paths", "set_reference_metadata", "set_reference_preview", + "set_reference_system_metadata", "soft_delete_reference_by_id", "set_reference_tags", "update_asset_hash_and_mime", + "update_is_missing_by_asset_id", "update_reference_access_time", "update_reference_name", "update_reference_timestamps", diff --git a/app/assets/database/queries/asset.py b/app/assets/database/queries/asset.py index a21f5b68f..594d1f1b2 100644 --- a/app/assets/database/queries/asset.py +++ b/app/assets/database/queries/asset.py @@ -69,7 +69,7 @@ def upsert_asset( if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: asset.size_bytes = int(size_bytes) changed = True - if mime_type and asset.mime_type != mime_type: + if mime_type and not asset.mime_type: asset.mime_type = mime_type changed = True if changed: @@ -118,7 +118,7 @@ def update_asset_hash_and_mime( return False if asset_hash is not None: asset.hash = asset_hash - if mime_type is not None: + if mime_type is not None and not asset.mime_type: asset.mime_type = mime_type return True diff --git a/app/assets/database/queries/asset_reference.py b/app/assets/database/queries/asset_reference.py index 6524791cc..084a32512 100644 --- a/app/assets/database/queries/asset_reference.py +++ b/app/assets/database/queries/asset_reference.py @@ -10,7 +10,7 @@ from decimal import Decimal from typing import NamedTuple, Sequence import sqlalchemy as sa -from sqlalchemy import delete, exists, select +from sqlalchemy import delete, select from sqlalchemy.dialects import sqlite from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session, noload @@ -24,12 +24,14 @@ from app.assets.database.models import ( ) from app.assets.database.queries.common import ( MAX_BIND_PARAMS, + apply_metadata_filter, + apply_tag_filters, build_prefix_like_conditions, build_visible_owner_clause, calculate_rows_per_statement, iter_chunks, ) -from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags +from app.assets.helpers import escape_sql_like_string, get_utc_now def _check_is_scalar(v): @@ -44,15 +46,6 @@ def _check_is_scalar(v): def _scalar_to_row(key: str, ordinal: int, value) -> dict: """Convert a scalar value to a typed projection row.""" - if value is None: - return { - "key": key, - "ordinal": ordinal, - "val_str": None, - "val_num": None, - "val_bool": None, - "val_json": None, - } if isinstance(value, bool): return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} if isinstance(value, (int, float, Decimal)): @@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict: def convert_metadata_to_rows(key: str, value) -> list[dict]: """Turn a metadata key/value into typed projection rows.""" if value is None: - return [_scalar_to_row(key, 0, None)] + return [] if _check_is_scalar(value): return [_scalar_to_row(key, 0, value)] if isinstance(value, list): if all(_check_is_scalar(x) for x in value): - return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] - return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] + return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None] + return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None] return [{"key": key, "ordinal": 0, "val_json": value}] -def _apply_tag_filters( - stmt: sa.sql.Select, - include_tags: Sequence[str] | None = None, - exclude_tags: Sequence[str] | None = None, -) -> sa.sql.Select: - """include_tags: every tag must be present; exclude_tags: none may be present.""" - include_tags = normalize_tags(include_tags) - exclude_tags = normalize_tags(exclude_tags) - - if include_tags: - for tag_name in include_tags: - stmt = stmt.where( - exists().where( - (AssetReferenceTag.asset_reference_id == AssetReference.id) - & (AssetReferenceTag.tag_name == tag_name) - ) - ) - - if exclude_tags: - stmt = stmt.where( - ~exists().where( - (AssetReferenceTag.asset_reference_id == AssetReference.id) - & (AssetReferenceTag.tag_name.in_(exclude_tags)) - ) - ) - return stmt - - -def _apply_metadata_filter( - stmt: sa.sql.Select, - metadata_filter: dict | None = None, -) -> sa.sql.Select: - """Apply filters using asset_reference_meta projection table.""" - if not metadata_filter: - return stmt - - def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: - return sa.exists().where( - AssetReferenceMeta.asset_reference_id == AssetReference.id, - AssetReferenceMeta.key == key, - *preds, - ) - - def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: - if value is None: - no_row_for_key = sa.not_( - sa.exists().where( - AssetReferenceMeta.asset_reference_id == AssetReference.id, - AssetReferenceMeta.key == key, - ) - ) - null_row = _exists_for_pred( - key, - AssetReferenceMeta.val_json.is_(None), - AssetReferenceMeta.val_str.is_(None), - AssetReferenceMeta.val_num.is_(None), - AssetReferenceMeta.val_bool.is_(None), - ) - return sa.or_(no_row_for_key, null_row) - - if isinstance(value, bool): - return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) - if isinstance(value, (int, float, Decimal)): - num = value if isinstance(value, Decimal) else Decimal(str(value)) - return _exists_for_pred(key, AssetReferenceMeta.val_num == num) - if isinstance(value, str): - return _exists_for_pred(key, AssetReferenceMeta.val_str == value) - return _exists_for_pred(key, AssetReferenceMeta.val_json == value) - - for k, v in metadata_filter.items(): - if isinstance(v, list): - ors = [_exists_clause_for_value(k, elem) for elem in v] - if ors: - stmt = stmt.where(sa.or_(*ors)) - else: - stmt = stmt.where(_exists_clause_for_value(k, v)) - return stmt def get_reference_by_id( @@ -212,6 +128,21 @@ def reference_exists_for_asset_id( return session.execute(q).first() is not None +def reference_exists( + session: Session, + reference_id: str, +) -> bool: + """Return True if a reference with the given ID exists (not soft-deleted).""" + q = ( + select(sa.literal(True)) + .select_from(AssetReference) + .where(AssetReference.id == reference_id) + .where(AssetReference.deleted_at.is_(None)) + .limit(1) + ) + return session.execute(q).first() is not None + + def insert_reference( session: Session, asset_id: str, @@ -336,8 +267,8 @@ def list_references_page( escaped, esc = escape_sql_like_string(name_contains) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) - base = _apply_tag_filters(base, include_tags, exclude_tags) - base = _apply_metadata_filter(base, metadata_filter) + base = apply_tag_filters(base, include_tags, exclude_tags) + base = apply_metadata_filter(base, metadata_filter) sort = (sort or "created_at").lower() order = (order or "desc").lower() @@ -366,8 +297,8 @@ def list_references_page( count_stmt = count_stmt.where( AssetReference.name.ilike(f"%{escaped}%", escape=esc) ) - count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) - count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) + count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags) + count_stmt = apply_metadata_filter(count_stmt, metadata_filter) total = int(session.execute(count_stmt).scalar_one() or 0) refs = session.execute(base).unique().scalars().all() @@ -379,7 +310,7 @@ def list_references_page( select(AssetReferenceTag.asset_reference_id, Tag.name) .join(Tag, Tag.name == AssetReferenceTag.tag_name) .where(AssetReferenceTag.asset_reference_id.in_(id_list)) - .order_by(AssetReferenceTag.added_at) + .order_by(AssetReferenceTag.tag_name.asc()) ) for ref_id, tag_name in rows.all(): tag_map[ref_id].append(tag_name) @@ -492,6 +423,42 @@ def update_reference_updated_at( ) +def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None: + """Delete and rebuild AssetReferenceMeta rows from merged system+user metadata. + + The merged dict is ``{**system_metadata, **user_metadata}`` so user keys + override system keys of the same name. + """ + session.execute( + delete(AssetReferenceMeta).where( + AssetReferenceMeta.asset_reference_id == ref.id + ) + ) + session.flush() + + merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})} + if not merged: + return + + rows: list[AssetReferenceMeta] = [] + for k, v in merged.items(): + for r in convert_metadata_to_rows(k, v): + rows.append( + AssetReferenceMeta( + asset_reference_id=ref.id, + key=r["key"], + ordinal=int(r["ordinal"]), + val_str=r.get("val_str"), + val_num=r.get("val_num"), + val_bool=r.get("val_bool"), + val_json=r.get("val_json"), + ) + ) + if rows: + session.add_all(rows) + session.flush() + + def set_reference_metadata( session: Session, reference_id: str, @@ -505,33 +472,24 @@ def set_reference_metadata( ref.updated_at = get_utc_now() session.flush() - session.execute( - delete(AssetReferenceMeta).where( - AssetReferenceMeta.asset_reference_id == reference_id - ) - ) + rebuild_metadata_projection(session, ref) + + +def set_reference_system_metadata( + session: Session, + reference_id: str, + system_metadata: dict | None = None, +) -> None: + """Set system_metadata on a reference and rebuild the merged projection.""" + ref = session.get(AssetReference, reference_id) + if not ref: + raise ValueError(f"AssetReference {reference_id} not found") + + ref.system_metadata = system_metadata or {} + ref.updated_at = get_utc_now() session.flush() - if not user_metadata: - return - - rows: list[AssetReferenceMeta] = [] - for k, v in user_metadata.items(): - for r in convert_metadata_to_rows(k, v): - rows.append( - AssetReferenceMeta( - asset_reference_id=reference_id, - key=r["key"], - ordinal=int(r["ordinal"]), - val_str=r.get("val_str"), - val_num=r.get("val_num"), - val_bool=r.get("val_bool"), - val_json=r.get("val_json"), - ) - ) - if rows: - session.add_all(rows) - session.flush() + rebuild_metadata_projection(session, ref) def delete_reference_by_id( @@ -571,19 +529,19 @@ def soft_delete_reference_by_id( def set_reference_preview( session: Session, reference_id: str, - preview_asset_id: str | None = None, + preview_reference_id: str | None = None, ) -> None: """Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" ref = session.get(AssetReference, reference_id) if not ref: raise ValueError(f"AssetReference {reference_id} not found") - if preview_asset_id is None: + if preview_reference_id is None: ref.preview_id = None else: - if not session.get(Asset, preview_asset_id): - raise ValueError(f"Preview Asset {preview_asset_id} not found") - ref.preview_id = preview_asset_id + if not session.get(AssetReference, preview_reference_id): + raise ValueError(f"Preview AssetReference {preview_reference_id} not found") + ref.preview_id = preview_reference_id ref.updated_at = get_utc_now() session.flush() @@ -609,6 +567,8 @@ def list_references_by_asset_id( session.execute( select(AssetReference) .where(AssetReference.asset_id == asset_id) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) .order_by(AssetReference.id.asc()) ) .scalars() @@ -616,6 +576,25 @@ def list_references_by_asset_id( ) +def list_all_file_paths_by_asset_id( + session: Session, + asset_id: str, +) -> list[str]: + """Return every file_path for an asset, including soft-deleted/missing refs. + + Used for orphan cleanup where all on-disk files must be removed. + """ + return list( + session.execute( + select(AssetReference.file_path) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.file_path.isnot(None)) + ) + .scalars() + .all() + ) + + def upsert_reference( session: Session, asset_id: str, @@ -855,6 +834,22 @@ def bulk_update_is_missing( return total +def update_is_missing_by_asset_id( + session: Session, asset_id: str, value: bool +) -> int: + """Set is_missing flag for ALL references belonging to an asset. + + Returns: Number of rows updated + """ + result = session.execute( + sa.update(AssetReference) + .where(AssetReference.asset_id == asset_id) + .where(AssetReference.deleted_at.is_(None)) + .values(is_missing=value) + ) + return result.rowcount + + def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: """Delete references by their IDs. diff --git a/app/assets/database/queries/common.py b/app/assets/database/queries/common.py index 194c39a1e..89bb49327 100644 --- a/app/assets/database/queries/common.py +++ b/app/assets/database/queries/common.py @@ -1,12 +1,14 @@ """Shared utilities for database query modules.""" import os -from typing import Iterable +from decimal import Decimal +from typing import Iterable, Sequence import sqlalchemy as sa +from sqlalchemy import exists -from app.assets.database.models import AssetReference -from app.assets.helpers import escape_sql_like_string +from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag +from app.assets.helpers import escape_sql_like_string, normalize_tags MAX_BIND_PARAMS = 800 @@ -52,3 +54,74 @@ def build_prefix_like_conditions( escaped, esc = escape_sql_like_string(base) conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) return conds + + +def apply_tag_filters( + stmt: sa.sql.Select, + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, +) -> sa.sql.Select: + """include_tags: every tag must be present; exclude_tags: none may be present.""" + include_tags = normalize_tags(include_tags) + exclude_tags = normalize_tags(exclude_tags) + + if include_tags: + for tag_name in include_tags: + stmt = stmt.where( + exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name == tag_name) + ) + ) + + if exclude_tags: + stmt = stmt.where( + ~exists().where( + (AssetReferenceTag.asset_reference_id == AssetReference.id) + & (AssetReferenceTag.tag_name.in_(exclude_tags)) + ) + ) + return stmt + + +def apply_metadata_filter( + stmt: sa.sql.Select, + metadata_filter: dict | None = None, +) -> sa.sql.Select: + """Apply filters using asset_reference_meta projection table.""" + if not metadata_filter: + return stmt + + def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement: + return sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + *preds, + ) + + def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement: + if value is None: + return sa.not_( + sa.exists().where( + AssetReferenceMeta.asset_reference_id == AssetReference.id, + AssetReferenceMeta.key == key, + ) + ) + + if isinstance(value, bool): + return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value)) + if isinstance(value, (int, float, Decimal)): + num = value if isinstance(value, Decimal) else Decimal(str(value)) + return _exists_for_pred(key, AssetReferenceMeta.val_num == num) + if isinstance(value, str): + return _exists_for_pred(key, AssetReferenceMeta.val_str == value) + return _exists_for_pred(key, AssetReferenceMeta.val_json == value) + + for k, v in metadata_filter.items(): + if isinstance(v, list): + ors = [_exists_clause_for_value(k, elem) for elem in v] + if ors: + stmt = stmt.where(sa.or_(*ors)) + else: + stmt = stmt.where(_exists_clause_for_value(k, v)) + return stmt diff --git a/app/assets/database/queries/tags.py b/app/assets/database/queries/tags.py index 8b25fee67..f4126dba8 100644 --- a/app/assets/database/queries/tags.py +++ b/app/assets/database/queries/tags.py @@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError from sqlalchemy.orm import Session from app.assets.database.models import ( + Asset, AssetReference, AssetReferenceMeta, AssetReferenceTag, Tag, ) from app.assets.database.queries.common import ( + apply_metadata_filter, + apply_tag_filters, build_visible_owner_clause, iter_row_chunks, ) @@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]: tag_name for (tag_name,) in ( session.execute( - select(AssetReferenceTag.tag_name).where( - AssetReferenceTag.asset_reference_id == reference_id - ) + select(AssetReferenceTag.tag_name) + .where(AssetReferenceTag.asset_reference_id == reference_id) + .order_by(AssetReferenceTag.tag_name.asc()) ) ).all() ] @@ -117,7 +120,7 @@ def set_reference_tags( ) session.flush() - return SetTagsResult(added=to_add, removed=to_remove, total=desired) + return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired)) def add_tags_to_reference( @@ -272,6 +275,12 @@ def list_tags_with_usage( .select_from(AssetReferenceTag) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .where(build_visible_owner_clause(owner_id)) + .where( + sa.or_( + AssetReference.is_missing == False, # noqa: E712 + AssetReferenceTag.tag_name == "missing", + ) + ) .where(AssetReference.deleted_at.is_(None)) .group_by(AssetReferenceTag.tag_name) .subquery() @@ -308,6 +317,12 @@ def list_tags_with_usage( select(AssetReferenceTag.tag_name) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .where(build_visible_owner_clause(owner_id)) + .where( + sa.or_( + AssetReference.is_missing == False, # noqa: E712 + AssetReferenceTag.tag_name == "missing", + ) + ) .where(AssetReference.deleted_at.is_(None)) .group_by(AssetReferenceTag.tag_name) ) @@ -320,6 +335,53 @@ def list_tags_with_usage( return rows_norm, int(total or 0) +def list_tag_counts_for_filtered_assets( + session: Session, + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, +) -> dict[str, int]: + """Return tag counts for assets matching the given filters. + + Uses the same filtering logic as list_references_page but returns + {tag_name: count} instead of paginated references. + """ + # Build a subquery of matching reference IDs + ref_sq = ( + select(AssetReference.id) + .join(Asset, Asset.id == AssetReference.asset_id) + .where(build_visible_owner_clause(owner_id)) + .where(AssetReference.is_missing == False) # noqa: E712 + .where(AssetReference.deleted_at.is_(None)) + ) + + if name_contains: + escaped, esc = escape_sql_like_string(name_contains) + ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) + + ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags) + ref_sq = apply_metadata_filter(ref_sq, metadata_filter) + ref_sq = ref_sq.subquery() + + # Count tags across those references + q = ( + select( + AssetReferenceTag.tag_name, + func.count(AssetReferenceTag.asset_reference_id).label("cnt"), + ) + .where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id))) + .group_by(AssetReferenceTag.tag_name) + .order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc()) + .limit(limit) + ) + + rows = session.execute(q).all() + return {tag_name: int(cnt) for tag_name, cnt in rows} + + def bulk_insert_tags_and_meta( session: Session, tag_rows: list[dict], diff --git a/app/assets/scanner.py b/app/assets/scanner.py index e27ea5123..4e05a97b5 100644 --- a/app/assets/scanner.py +++ b/app/assets/scanner.py @@ -18,7 +18,7 @@ from app.assets.database.queries import ( mark_references_missing_outside_prefixes, reassign_asset_references, remove_missing_tag_for_asset_id, - set_reference_metadata, + set_reference_system_metadata, update_asset_hash_and_mime, ) from app.assets.services.bulk_ingest import ( @@ -490,8 +490,8 @@ def enrich_asset( logging.warning("Failed to hash %s: %s", file_path, e) if extract_metadata and metadata: - user_metadata = metadata.to_user_metadata() - set_reference_metadata(session, reference_id, user_metadata) + system_metadata = metadata.to_user_metadata() + set_reference_system_metadata(session, reference_id, system_metadata) if full_hash: existing = get_asset_by_hash(session, full_hash) diff --git a/app/assets/services/asset_management.py b/app/assets/services/asset_management.py index 3fe7115c8..5aefd9956 100644 --- a/app/assets/services/asset_management.py +++ b/app/assets/services/asset_management.py @@ -16,10 +16,12 @@ from app.assets.database.queries import ( get_reference_by_id, get_reference_with_owner_check, list_references_page, + list_all_file_paths_by_asset_id, list_references_by_asset_id, set_reference_metadata, set_reference_preview, set_reference_tags, + update_asset_hash_and_mime, update_reference_access_time, update_reference_name, update_reference_updated_at, @@ -67,6 +69,8 @@ def update_asset_metadata( user_metadata: UserMetadata = None, tag_origin: str = "manual", owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, ) -> AssetDetailResult: with create_session() as session: ref = get_reference_with_owner_check(session, reference_id, owner_id) @@ -103,6 +107,21 @@ def update_asset_metadata( ) touched = True + if mime_type is not None: + updated = update_asset_hash_and_mime( + session, asset_id=ref.asset_id, mime_type=mime_type + ) + if updated: + touched = True + + if preview_id is not None: + set_reference_preview( + session, + reference_id=reference_id, + preview_reference_id=preview_id, + ) + touched = True + if touched and user_metadata is None: update_reference_updated_at(session, reference_id=reference_id) @@ -159,11 +178,9 @@ def delete_asset_reference( session.commit() return True - # Orphaned asset - delete it and its files - refs = list_references_by_asset_id(session, asset_id=asset_id) - file_paths = [ - r.file_path for r in (refs or []) if getattr(r, "file_path", None) - ] + # Orphaned asset - gather ALL file paths (including + # soft-deleted / missing refs) so their on-disk files get cleaned up. + file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id) # Also include the just-deleted file path if file_path: file_paths.append(file_path) @@ -185,7 +202,7 @@ def delete_asset_reference( def set_asset_preview( reference_id: str, - preview_asset_id: str | None = None, + preview_reference_id: str | None = None, owner_id: str = "", ) -> AssetDetailResult: with create_session() as session: @@ -194,7 +211,7 @@ def set_asset_preview( set_reference_preview( session, reference_id=reference_id, - preview_asset_id=preview_asset_id, + preview_reference_id=preview_reference_id, ) result = fetch_reference_asset_and_tags( @@ -263,6 +280,47 @@ def list_assets_page( return ListAssetsResult(items=items, total=total) +def resolve_hash_to_path( + asset_hash: str, + owner_id: str = "", +) -> DownloadResolutionResult | None: + """Resolve a blake3 hash to an on-disk file path. + + Only references visible to *owner_id* are considered (owner-less + references are always visible). + + Returns a DownloadResolutionResult with abs_path, content_type, and + download_name, or None if no asset or live path is found. + """ + with create_session() as session: + asset = queries_get_asset_by_hash(session, asset_hash) + if not asset: + return None + refs = list_references_by_asset_id(session, asset_id=asset.id) + visible = [ + r for r in refs + if r.owner_id == "" or r.owner_id == owner_id + ] + abs_path = select_best_live_path(visible) + if not abs_path: + return None + display_name = os.path.basename(abs_path) + for ref in visible: + if ref.file_path == abs_path and ref.name: + display_name = ref.name + break + ctype = ( + asset.mime_type + or mimetypes.guess_type(display_name)[0] + or "application/octet-stream" + ) + return DownloadResolutionResult( + abs_path=abs_path, + content_type=ctype, + download_name=display_name, + ) + + def resolve_asset_for_download( reference_id: str, owner_id: str = "", diff --git a/app/assets/services/ingest.py b/app/assets/services/ingest.py index 44d7aef36..90c51994f 100644 --- a/app/assets/services/ingest.py +++ b/app/assets/services/ingest.py @@ -11,13 +11,14 @@ from app.assets.database.queries import ( add_tags_to_reference, fetch_reference_and_asset, get_asset_by_hash, - get_existing_asset_ids, get_reference_by_file_path, get_reference_tags, get_or_create_reference, + reference_exists, remove_missing_tag_for_asset_id, set_reference_metadata, set_reference_tags, + update_asset_hash_and_mime, upsert_asset, upsert_reference, validate_tags_exist, @@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.path_utils import ( compute_relative_filename, + get_name_and_tags_from_asset_path, resolve_destination_from_tags, validate_path_within_base, ) @@ -65,7 +67,7 @@ def _ingest_file_from_path( with create_session() as session: if preview_id: - if preview_id not in get_existing_asset_ids(session, [preview_id]): + if not reference_exists(session, preview_id): preview_id = None asset, asset_created, asset_updated = upsert_asset( @@ -135,6 +137,8 @@ def _register_existing_asset( tags: list[str] | None = None, tag_origin: str = "manual", owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, ) -> RegisterAssetResult: user_metadata = user_metadata or {} @@ -143,14 +147,25 @@ def _register_existing_asset( if not asset: raise ValueError(f"No asset with hash {asset_hash}") + if mime_type and not asset.mime_type: + update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type) + + if preview_id: + if not reference_exists(session, preview_id): + preview_id = None + ref, ref_created = get_or_create_reference( session, asset_id=asset.id, owner_id=owner_id, name=name, + preview_id=preview_id, ) if not ref_created: + if preview_id and ref.preview_id != preview_id: + ref.preview_id = preview_id + tag_names = get_reference_tags(session, reference_id=ref.id) result = RegisterAssetResult( ref=extract_reference_data(ref), @@ -242,6 +257,8 @@ def upload_from_temp_path( client_filename: str | None = None, owner_id: str = "", expected_hash: str | None = None, + mime_type: str | None = None, + preview_id: str | None = None, ) -> UploadResult: try: digest, _ = hashing.compute_blake3_hash(temp_path) @@ -270,6 +287,8 @@ def upload_from_temp_path( tags=tags or [], tag_origin="manual", owner_id=owner_id, + mime_type=mime_type, + preview_id=preview_id, ) return UploadResult( ref=result.ref, @@ -291,7 +310,7 @@ def upload_from_temp_path( dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) validate_path_within_base(dest_abs, base_dir) - content_type = ( + content_type = mime_type or ( mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] or mimetypes.guess_type(hashed_basename, strict=False)[0] or "application/octet-stream" @@ -315,7 +334,7 @@ def upload_from_temp_path( mime_type=content_type, info_name=_sanitize_filename(name or client_filename, fallback=digest), owner_id=owner_id, - preview_id=None, + preview_id=preview_id, user_metadata=user_metadata or {}, tags=tags, tag_origin="manual", @@ -342,30 +361,99 @@ def upload_from_temp_path( ) +def register_file_in_place( + abs_path: str, + name: str, + tags: list[str], + owner_id: str = "", + mime_type: str | None = None, +) -> UploadResult: + """Register an already-saved file in the asset database without moving it. + + Tags are derived from the filesystem path (root category + subfolder names), + merged with any caller-provided tags, matching the behavior of the scanner. + If the path is not under a known root, only the caller-provided tags are used. + """ + try: + _, path_tags = get_name_and_tags_from_asset_path(abs_path) + except ValueError: + path_tags = [] + merged_tags = normalize_tags([*path_tags, *tags]) + + try: + digest, _ = hashing.compute_blake3_hash(abs_path) + except ImportError as e: + raise DependencyMissingError(str(e)) + except Exception as e: + raise RuntimeError(f"failed to hash file: {e}") + asset_hash = "blake3:" + digest + + size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path) + content_type = mime_type or ( + mimetypes.guess_type(abs_path, strict=False)[0] + or "application/octet-stream" + ) + + ingest_result = _ingest_file_from_path( + abs_path=abs_path, + asset_hash=asset_hash, + size_bytes=size_bytes, + mtime_ns=mtime_ns, + mime_type=content_type, + info_name=_sanitize_filename(name, fallback=digest), + owner_id=owner_id, + tags=merged_tags, + tag_origin="upload", + require_existing_tags=False, + ) + reference_id = ingest_result.reference_id + if not reference_id: + raise RuntimeError("failed to create asset reference") + + with create_session() as session: + pair = fetch_reference_and_asset( + session, reference_id=reference_id, owner_id=owner_id + ) + if not pair: + raise RuntimeError("inconsistent DB state after ingest") + ref, asset = pair + tag_names = get_reference_tags(session, reference_id=ref.id) + + return UploadResult( + ref=extract_reference_data(ref), + asset=extract_asset_data(asset), + tags=tag_names, + created_new=ingest_result.asset_created, + ) + + def create_from_hash( hash_str: str, name: str, tags: list[str] | None = None, user_metadata: dict | None = None, owner_id: str = "", + mime_type: str | None = None, + preview_id: str | None = None, ) -> UploadResult | None: canonical = hash_str.strip().lower() - with create_session() as session: - asset = get_asset_by_hash(session, asset_hash=canonical) - if not asset: - return None - - result = _register_existing_asset( - asset_hash=canonical, - name=_sanitize_filename( - name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical - ), - user_metadata=user_metadata or {}, - tags=tags or [], - tag_origin="manual", - owner_id=owner_id, - ) + try: + result = _register_existing_asset( + asset_hash=canonical, + name=_sanitize_filename( + name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical + ), + user_metadata=user_metadata or {}, + tags=tags or [], + tag_origin="manual", + owner_id=owner_id, + mime_type=mime_type, + preview_id=preview_id, + ) + except ValueError: + logging.warning("create_from_hash: no asset found for hash %s", canonical) + return None return UploadResult( ref=result.ref, diff --git a/app/assets/services/schemas.py b/app/assets/services/schemas.py index 8b1f1f4dc..0eb128f58 100644 --- a/app/assets/services/schemas.py +++ b/app/assets/services/schemas.py @@ -25,7 +25,9 @@ class ReferenceData: preview_id: str | None created_at: datetime updated_at: datetime - last_access_time: datetime | None + system_metadata: dict[str, Any] | None = None + job_id: str | None = None + last_access_time: datetime | None = None @dataclass(frozen=True) @@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData: file_path=ref.file_path, user_metadata=ref.user_metadata, preview_id=ref.preview_id, + system_metadata=ref.system_metadata, + job_id=ref.job_id, created_at=ref.created_at, updated_at=ref.updated_at, last_access_time=ref.last_access_time, diff --git a/app/assets/services/tagging.py b/app/assets/services/tagging.py index 28900464d..37b612753 100644 --- a/app/assets/services/tagging.py +++ b/app/assets/services/tagging.py @@ -1,3 +1,5 @@ +from typing import Sequence + from app.assets.database.queries import ( AddTagsResult, RemoveTagsResult, @@ -6,6 +8,7 @@ from app.assets.database.queries import ( list_tags_with_usage, remove_tags_from_reference, ) +from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets from app.assets.services.schemas import TagUsage from app.database.db import create_session @@ -73,3 +76,23 @@ def list_tags( ) return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total + + +def list_tag_histogram( + owner_id: str = "", + include_tags: Sequence[str] | None = None, + exclude_tags: Sequence[str] | None = None, + name_contains: str | None = None, + metadata_filter: dict | None = None, + limit: int = 100, +) -> dict[str, int]: + with create_session() as session: + return list_tag_counts_for_filtered_assets( + session, + owner_id=owner_id, + include_tags=include_tags, + exclude_tags=exclude_tags, + name_contains=name_contains, + metadata_filter=metadata_filter, + limit=limit, + ) diff --git a/app/database/models.py b/app/database/models.py index e7572677a..b02856f6e 100644 --- a/app/database/models.py +++ b/app/database/models.py @@ -1,9 +1,18 @@ from typing import Any from datetime import datetime +from sqlalchemy import MetaData from sqlalchemy.orm import DeclarativeBase +NAMING_CONVENTION = { + "ix": "ix_%(table_name)s_%(column_0_N_name)s", + "uq": "uq_%(table_name)s_%(column_0_N_name)s", + "ck": "ck_%(table_name)s_%(constraint_name)s", + "fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s", + "pk": "pk_%(table_name)s", +} + class Base(DeclarativeBase): - pass + metadata = MetaData(naming_convention=NAMING_CONVENTION) def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: fields = obj.__table__.columns.keys() diff --git a/app/user_manager.py b/app/user_manager.py index e2c00dab2..e18afb71b 100644 --- a/app/user_manager.py +++ b/app/user_manager.py @@ -6,6 +6,7 @@ import uuid import glob import shutil import logging +import tempfile from aiohttp import web from urllib import parse from comfy.cli_args import args @@ -377,8 +378,15 @@ class UserManager(): try: body = await request.read() - with open(path, "wb") as f: - f.write(body) + dir_name = os.path.dirname(path) + fd, tmp_path = tempfile.mkstemp(dir=dir_name) + try: + with os.fdopen(fd, "wb") as f: + f.write(body) + os.replace(tmp_path, path) + except: + os.unlink(tmp_path) + raise except OSError as e: logging.warning(f"Error saving file '{path}': {e}") return web.Response( diff --git a/comfy/cli_args.py b/comfy/cli_args.py index 0a0bf2f30..13612175e 100644 --- a/comfy/cli_args.py +++ b/comfy/cli_args.py @@ -149,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") +parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") @@ -262,4 +263,6 @@ else: args.fast = set(args.fast) def enables_dynamic_vram(): + if args.enable_dynamic_vram: + return True return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu diff --git a/comfy/context_windows.py b/comfy/context_windows.py index b54f7f39a..cb44ee6e8 100644 --- a/comfy/context_windows.py +++ b/comfy/context_windows.py @@ -93,6 +93,50 @@ class IndexListCallbacks: return {} +def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]): + if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)): + return None + cond_tensor = cond_value.cond + if temporal_dim >= cond_tensor.ndim: + return None + + cond_size = cond_tensor.size(temporal_dim) + + if temporal_scale == 1: + expected_size = x_in.size(window.dim) - temporal_offset + if cond_size != expected_size: + return None + + if temporal_offset == 0 and temporal_scale == 1: + sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list) + return cond_value._copy_with(sliced) + + # skip leading latent positions that have no corresponding conditioning (e.g. reference frames) + if temporal_offset > 0: + indices = [i - temporal_offset for i in window.index_list[temporal_offset:]] + indices = [i for i in indices if 0 <= i] + else: + indices = list(window.index_list) + + if not indices: + return None + + if temporal_scale > 1: + scaled = [] + for i in indices: + for k in range(temporal_scale): + si = i * temporal_scale + k + if si < cond_size: + scaled.append(si) + indices = scaled + if not indices: + return None + + idx = tuple([slice(None)] * temporal_dim + [indices]) + sliced = cond_tensor[idx].to(device) + return cond_value._copy_with(sliced) + + @dataclass class ContextSchedule: name: str @@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC): new_cond_item[cond_key] = result handled = True break + if not handled and self._model is not None: + result = self._model.resize_cond_for_context_window( + cond_key, cond_value, window, x_in, device, + retain_index_list=self.cond_retain_index_list) + if result is not None: + new_cond_item[cond_key] = result + handled = True if handled: continue if isinstance(cond_value, torch.Tensor): - if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ + if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \ (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): new_cond_item[cond_key] = window.get_tensor(cond_value, device) # Handle audio_embed (temporal dim is 1) @@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC): return context_windows def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): + self._model = model self.set_step(timestep, model_options) context_windows = self.get_context_windows(model, x_in, model_options) enumerated_context_windows = list(enumerate(context_windows)) diff --git a/comfy/ldm/cascade/stage_a.py b/comfy/ldm/cascade/stage_a.py index 145e6e69a..e4e30cacd 100644 --- a/comfy/ldm/cascade/stage_a.py +++ b/comfy/ldm/cascade/stage_a.py @@ -136,16 +136,7 @@ class ResBlock(nn.Module): ops.Linear(c_hidden, c), ) - self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) - - # Init weights - def _basic_init(module): - if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d): - torch.nn.init.xavier_uniform_(module.weight) - if module.bias is not None: - nn.init.constant_(module.bias, 0) - - self.apply(_basic_init) + self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False) def _norm(self, x, norm): return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) diff --git a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py index d48d9d642..f67ba84e9 100644 --- a/comfy/ldm/hunyuan3dv2_1/hunyuandit.py +++ b/comfy/ldm/hunyuan3dv2_1/hunyuandit.py @@ -343,6 +343,7 @@ class CrossAttention(nn.Module): k.reshape(b, s2, self.num_heads * self.head_dim), v, heads=self.num_heads, + low_precision_attention=False, ) out = self.out_proj(x) @@ -412,6 +413,7 @@ class Attention(nn.Module): key.reshape(B, N, self.num_heads * self.head_dim), value, heads=self.num_heads, + low_precision_attention=False, ) x = self.out_proj(x) diff --git a/comfy/ldm/lightricks/vae/causal_conv3d.py b/comfy/ldm/lightricks/vae/causal_conv3d.py index b8341edbc..7515f0d4e 100644 --- a/comfy/ldm/lightricks/vae/causal_conv3d.py +++ b/comfy/ldm/lightricks/vae/causal_conv3d.py @@ -23,6 +23,11 @@ class CausalConv3d(nn.Module): self.in_channels = in_channels self.out_channels = out_channels + if isinstance(stride, int): + self.time_stride = stride + else: + self.time_stride = stride[0] + kernel_size = (kernel_size, kernel_size, kernel_size) self.time_kernel_size = kernel_size[0] @@ -58,16 +63,25 @@ class CausalConv3d(nn.Module): pieces = [ cached, x ] if is_end and not causal: pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) + input_length = sum([piece.shape[2] for piece in pieces]) + cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride) needs_caching = not is_end - if needs_caching and x.shape[2] >= self.time_kernel_size - 1: + if needs_caching and cache_length == 0: + self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False) needs_caching = False - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + if needs_caching and x.shape[2] >= cache_length: + needs_caching = False + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) x = torch.cat(pieces, dim=2) + del pieces + del cached if needs_caching: - self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) + self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False) + elif is_end: + self.temporal_cache_state[tid] = (None, True) return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] diff --git a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py index 9f14f64a5..998122c85 100644 --- a/comfy/ldm/lightricks/vae/causal_video_autoencoder.py +++ b/comfy/ldm/lightricks/vae/causal_video_autoencoder.py @@ -233,10 +233,7 @@ class Encoder(nn.Module): self.gradient_checkpointing = False - def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: - r"""The forward method of the `Encoder` class.""" - - sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]: sample = self.conv_in(sample) checkpoint_fn = ( @@ -247,10 +244,14 @@ class Encoder(nn.Module): for down_block in self.down_blocks: sample = checkpoint_fn(down_block)(sample) + if sample is None or sample.shape[2] == 0: + return None sample = self.conv_norm_out(sample) sample = self.conv_act(sample) sample = self.conv_out(sample) + if sample is None or sample.shape[2] == 0: + return None if self.latent_log_var == "uniform": last_channel = sample[:, -1:, ...] @@ -282,9 +283,35 @@ class Encoder(nn.Module): return sample + def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor: + r"""The forward method of the `Encoder` class.""" + + max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder + frame_size = sample[:, :, :1, :, :].numel() * sample.element_size() + frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels)) + + outputs = [] + samples = [sample[:, :, :1, :, :]] + if sample.shape[2] > 1: + chunk_t = max(2, max_chunk_size // frame_size) + if chunk_t < 4: + chunk_t = 2 + elif chunk_t < 8: + chunk_t = 4 + else: + chunk_t = (chunk_t // 8) * 8 + samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2)) + for chunk_idx, chunk in enumerate(samples): + if chunk_idx == len(samples) - 1: + mark_conv3d_ended(self) + chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device) + output = self._forward_chunk(chunk) + if output is not None: + outputs.append(output) + + return torch_cat_if_needed(outputs, dim=2) + def forward(self, *args, **kwargs): - #No encoder support so just flag the end so it doesnt use the cache. - mark_conv3d_ended(self) try: return self.forward_orig(*args, **kwargs) finally: @@ -297,7 +324,23 @@ class Encoder(nn.Module): module.temporal_cache_state.pop(tid, None) -MAX_CHUNK_SIZE=(128 * 1024 ** 2) +MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3 +MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3 +MIN_CHUNK_SIZE = 32 * 1024 ** 2 +MAX_CHUNK_SIZE = 128 * 1024 ** 2 + +def get_max_chunk_size(device: torch.device) -> int: + total_memory = comfy.model_management.get_total_memory(dev=device) + + if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING: + return MIN_CHUNK_SIZE + if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING: + return MAX_CHUNK_SIZE + + interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / ( + MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING + ) + return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE)) class Decoder(nn.Module): r""" @@ -457,6 +500,17 @@ class Decoder(nn.Module): self.gradient_checkpointing = False + # Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset) + ts, hs, ws, to = 1, 1, 1, 0 + for block in self.up_blocks: + if isinstance(block, DepthToSpaceUpsample): + ts *= block.stride[0] + hs *= block.stride[1] + ws *= block.stride[2] + if block.stride[0] > 1: + to = to * block.stride[0] + 1 + self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to) + self.timestep_conditioning = timestep_conditioning if timestep_conditioning: @@ -478,11 +532,62 @@ class Decoder(nn.Module): ) - # def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: + def decode_output_shape(self, input_shape): + c, (ts, hs, ws), to = self._output_scale + return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws) + + def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size): + sample = sample_ref[0] + sample_ref[0] = None + if idx >= len(self.up_blocks): + sample = self.conv_norm_out(sample) + if timestep_shift_scale is not None: + shift, scale = timestep_shift_scale + sample = sample * (1 + scale) + shift + sample = self.conv_act(sample) + if ended: + mark_conv3d_ended(self.conv_out) + sample = self.conv_out(sample, causal=self.causal) + if sample is not None and sample.shape[2] > 0: + sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) + t = sample.shape[2] + output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample) + output_offset[0] += t + return + + up_block = self.up_blocks[idx] + if ended: + mark_conv3d_ended(up_block) + if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): + sample = checkpoint_fn(up_block)( + sample, causal=self.causal, timestep=scaled_timestep + ) + else: + sample = checkpoint_fn(up_block)(sample, causal=self.causal) + + if sample is None or sample.shape[2] == 0: + return + + total_bytes = sample.numel() * sample.element_size() + num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size + + if num_chunks == 1: + # when we are not chunking, detach our x so the callee can free it as soon as they are done + next_sample_ref = [sample] + del sample + self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + return + else: + samples = torch.chunk(sample, chunks=num_chunks, dim=2) + + for chunk_idx, sample1 in enumerate(samples): + self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) + def forward_orig( self, sample: torch.FloatTensor, timestep: Optional[torch.Tensor] = None, + output_buffer: Optional[torch.Tensor] = None, ) -> torch.FloatTensor: r"""The forward method of the `Decoder` class.""" batch_size = sample.shape[0] @@ -497,6 +602,7 @@ class Decoder(nn.Module): ) timestep_shift_scale = None + scaled_timestep = None if self.timestep_conditioning: assert ( timestep is not None @@ -524,48 +630,18 @@ class Decoder(nn.Module): ) timestep_shift_scale = ada_values.unbind(dim=1) - output = [] + if output_buffer is None: + output_buffer = torch.empty( + self.decode_output_shape(sample.shape), + dtype=sample.dtype, device=comfy.model_management.intermediate_device(), + ) + output_offset = [0] - def run_up(idx, sample, ended): - if idx >= len(self.up_blocks): - sample = self.conv_norm_out(sample) - if timestep_shift_scale is not None: - shift, scale = timestep_shift_scale - sample = sample * (1 + scale) + shift - sample = self.conv_act(sample) - if ended: - mark_conv3d_ended(self.conv_out) - sample = self.conv_out(sample, causal=self.causal) - if sample is not None and sample.shape[2] > 0: - output.append(sample.to(comfy.model_management.intermediate_device())) - return + max_chunk_size = get_max_chunk_size(sample.device) - up_block = self.up_blocks[idx] - if (ended): - mark_conv3d_ended(up_block) - if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D): - sample = checkpoint_fn(up_block)( - sample, causal=self.causal, timestep=scaled_timestep - ) - else: - sample = checkpoint_fn(up_block)(sample, causal=self.causal) + self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size) - if sample is None or sample.shape[2] == 0: - return - - total_bytes = sample.numel() * sample.element_size() - num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE - samples = torch.chunk(sample, chunks=num_chunks, dim=2) - - for chunk_idx, sample1 in enumerate(samples): - run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1) - - run_up(0, sample, True) - sample = torch.cat(output, dim=2) - - sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1) - - return sample + return output_buffer def forward(self, *args, **kwargs): try: @@ -689,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module): causal=True, spatial_padding_mode=spatial_padding_mode, ) + self.temporal_cache_state = {} def forward(self, x, causal: bool = True): - if self.stride[0] == 2: + tid = threading.get_ident() + cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None)) + if cached_input is not None: + x = torch_cat_if_needed([cached_input, x], dim=2) + cached_input = None + + if self.stride[0] == 2 and pad_first: x = torch.cat( [x[:, :, :1, :, :], x], dim=2 ) # duplicate first frames for padding + pad_first = False + + if x.shape[2] < self.stride[0]: + cached_input = x + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) + return None # skip connection x_in = rearrange( @@ -709,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module): # conv x = self.conv(x, causal=causal) - x = rearrange( - x, - "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", - p1=self.stride[0], - p2=self.stride[1], - p3=self.stride[2], - ) + if self.stride[0] == 2 and x.shape[2] == 1: + if cached_x is not None: + x = torch_cat_if_needed([cached_x, x], dim=2) + cached_x = None + else: + cached_x = x + x = None - x = x + x_in + if x is not None: + x = rearrange( + x, + "b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", + p1=self.stride[0], + p2=self.stride[1], + p3=self.stride[2], + ) + + cached = add_exchange_cache(x, cached, x_in, dim=2) + + self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input) return x @@ -1050,6 +1150,8 @@ class processor(nn.Module): return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) class VideoVAE(nn.Module): + comfy_has_chunked_io = True + def __init__(self, version=0, config=None): super().__init__() @@ -1192,14 +1294,15 @@ class VideoVAE(nn.Module): } return config - def encode(self, x): - frames_count = x.shape[2] - if ((frames_count - 1) % 8) != 0: - raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.") - means, logvar = torch.chunk(self.encoder(x), 2, dim=1) + def encode(self, x, device=None): + x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :] + means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1) return self.per_channel_statistics.normalize(means) - def decode(self, x): + def decode_output_shape(self, input_shape): + return self.decoder.decode_output_shape(input_shape) + + def decode(self, x, output_buffer=None): if self.timestep_conditioning: #TODO: seed x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x - return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) + return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer) diff --git a/comfy/ldm/wan/vae.py b/comfy/ldm/wan/vae.py index 71f73c64e..deeb8695b 100644 --- a/comfy/ldm/wan/vae.py +++ b/comfy/ldm/wan/vae.py @@ -99,7 +99,7 @@ class Resample(nn.Module): else: self.resample = nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): b, c, t, h, w = x.size() if self.mode == 'upsample3d': if feat_cache is not None: @@ -109,22 +109,7 @@ class Resample(nn.Module): feat_idx[0] += 1 else: - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] != 'Rep': - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - if cache_x.shape[2] < 2 and feat_cache[ - idx] is not None and feat_cache[idx] == 'Rep': - cache_x = torch.cat([ - torch.zeros_like(cache_x).to(cache_x.device), - cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] if feat_cache[idx] == 'Rep': x = self.time_conv(x) else: @@ -145,19 +130,24 @@ class Resample(nn.Module): if feat_cache is not None: idx = feat_idx[0] if feat_cache[idx] is None: - feat_cache[idx] = x.clone() - feat_idx[0] += 1 + feat_cache[idx] = x else: - cache_x = x[:, :, -1:, :, :].clone() - # if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep': - # # cache last frame of last two chunk - # cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2) - + cache_x = x[:, :, -1:, :, :] x = self.time_conv( torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) feat_cache[idx] = cache_x - feat_idx[0] += 1 + + deferred_x = feat_cache[idx + 1] + if deferred_x is not None: + x = torch.cat([deferred_x, x], 2) + feat_cache[idx + 1] = None + + if x.shape[2] == 1 and not final: + feat_cache[idx + 1] = x + x = None + + feat_idx[0] += 2 return x @@ -177,19 +167,12 @@ class ResidualBlock(nn.Module): self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ if in_dim != out_dim else nn.Identity() - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): old_x = x for layer in self.residual: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, cache_list=feat_cache, cache_idx=idx) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -213,7 +196,7 @@ class AttentionBlock(nn.Module): self.proj = ops.Conv2d(dim, dim, 1) self.optimized_attention = vae_attention() - def forward(self, x): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): identity = x b, c, t, h, w = x.size() x = rearrange(x, 'b c t h w -> (b t) c h w') @@ -283,17 +266,10 @@ class Encoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, z_dim, 3, padding=1)) - def forward(self, x, feat_cache=None, feat_idx=[0]): + def forward(self, x, feat_cache=None, feat_idx=[0], final=False): if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -303,14 +279,16 @@ class Encoder3d(nn.Module): ## downsamples for layer in self.downsamples: if feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + x = layer(x, feat_cache, feat_idx, final=final) + if x is None: + return None else: x = layer(x) ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx, final=final) else: x = layer(x) @@ -318,14 +296,7 @@ class Encoder3d(nn.Module): for layer in self.head: if isinstance(layer, CausalConv3d) and feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = layer(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -389,18 +360,48 @@ class Decoder3d(nn.Module): RMS_norm(out_dim, images=False), nn.SiLU(), CausalConv3d(out_dim, output_channels, 3, padding=1)) + def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks): + x = x_ref[0] + x_ref[0] = None + if layer_idx >= len(self.upsamples): + for layer in self.head: + if isinstance(layer, CausalConv3d) and feat_cache is not None: + cache_x = x[:, :, -CACHE_T:, :, :] + x = layer(x, feat_cache[feat_idx[0]]) + feat_cache[feat_idx[0]] = cache_x + feat_idx[0] += 1 + else: + x = layer(x) + out_chunks.append(x) + return + + layer = self.upsamples[layer_idx] + if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1: + for frame_idx in range(x.shape[2]): + self.run_up( + layer_idx, + [x[:, :, frame_idx:frame_idx + 1, :, :]], + feat_cache, + feat_idx.copy(), + out_chunks, + ) + del x + return + + if feat_cache is not None: + x = layer(x, feat_cache, feat_idx) + else: + x = layer(x) + + next_x_ref = [x] + del x + self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks) + def forward(self, x, feat_cache=None, feat_idx=[0]): ## conv1 if feat_cache is not None: idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) + cache_x = x[:, :, -CACHE_T:, :, :] x = self.conv1(x, feat_cache[idx]) feat_cache[idx] = cache_x feat_idx[0] += 1 @@ -409,42 +410,21 @@ class Decoder3d(nn.Module): ## middle for layer in self.middle: - if isinstance(layer, ResidualBlock) and feat_cache is not None: - x = layer(x, feat_cache, feat_idx) - else: - x = layer(x) - - ## upsamples - for layer in self.upsamples: if feat_cache is not None: x = layer(x, feat_cache, feat_idx) else: x = layer(x) - ## head - for layer in self.head: - if isinstance(layer, CausalConv3d) and feat_cache is not None: - idx = feat_idx[0] - cache_x = x[:, :, -CACHE_T:, :, :].clone() - if cache_x.shape[2] < 2 and feat_cache[idx] is not None: - # cache last frame of last two chunk - cache_x = torch.cat([ - feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to( - cache_x.device), cache_x - ], - dim=2) - x = layer(x, feat_cache[idx]) - feat_cache[idx] = cache_x - feat_idx[0] += 1 - else: - x = layer(x) - return x + out_chunks = [] + + self.run_up(0, [x], feat_cache, feat_idx, out_chunks) + return out_chunks -def count_conv3d(model): +def count_cache_layers(model): count = 0 for m in model.modules(): - if isinstance(m, CausalConv3d): + if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'): count += 1 return count @@ -482,11 +462,12 @@ class WanVAE(nn.Module): conv_idx = [0] ## cache t = x.shape[2] - iter_ = 1 + (t - 1) // 4 + t = 1 + ((t - 1) // 4) * 4 + iter_ = 1 + (t - 1) // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.encoder) - ## 对encode输入的x,按时间拆分为1、4、4、4.... + feat_map = [None] * count_cache_layers(self.encoder) + ## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整) for i in range(iter_): conv_idx = [0] if i == 0: @@ -496,20 +477,23 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.encoder( - x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, - feat_idx=conv_idx) + feat_idx=conv_idx, + final=(i == (iter_ - 1))) + if out_ is None: + continue out = torch.cat([out, out_], 2) + mu, log_var = self.conv1(out).chunk(2, dim=1) return mu def decode(self, z): - conv_idx = [0] # z: [b,c,t,h,w] - iter_ = z.shape[2] + iter_ = 1 + z.shape[2] // 2 feat_map = None if iter_ > 1: - feat_map = [None] * count_conv3d(self.decoder) + feat_map = [None] * count_cache_layers(self.decoder) x = self.conv2(z) for i in range(iter_): conv_idx = [0] @@ -520,8 +504,8 @@ class WanVAE(nn.Module): feat_idx=conv_idx) else: out_ = self.decoder( - x[:, :, i:i + 1, :, :], + x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :], feat_cache=feat_map, feat_idx=conv_idx) - out = torch.cat([out, out_], 2) - return out + out += out_ + return torch.cat(out, 2) diff --git a/comfy/memory_management.py b/comfy/memory_management.py index 563224098..f9078fe7c 100644 --- a/comfy/memory_management.py +++ b/comfy/memory_management.py @@ -39,7 +39,10 @@ def read_tensor_file_slice_into(tensor, destination): if (destination.device.type != "cpu" or file_obj is None or threading.get_ident() != info.thread_id - or destination.numel() * destination.element_size() < info.size): + or destination.numel() * destination.element_size() < info.size + or tensor.numel() * tensor.element_size() != info.size + or tensor.storage_offset() != 0 + or not tensor.is_contiguous()): return False if info.size == 0: diff --git a/comfy/model_base.py b/comfy/model_base.py index d9d5a9293..43ec93324 100644 --- a/comfy/model_base.py +++ b/comfy/model_base.py @@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit import torch import logging import comfy.ldm.lightricks.av_model +import comfy.context_windows from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_b import StageB @@ -285,6 +286,12 @@ class BaseModel(torch.nn.Module): return data return None + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + """Override in subclasses to handle model-specific cond slicing for context windows. + Return a sliced cond object, or None to fall through to default handling. + Use comfy.context_windows.slice_cond() for common cases.""" + return None + def extra_conds(self, **kwargs): out = {} concat_cond = self.concat_cond(**kwargs) @@ -1375,6 +1382,11 @@ class WAN21_Vace(WAN21): out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "vace_context": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN21_Camera(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) @@ -1427,6 +1439,11 @@ class WAN21_HuMo(WAN21): return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_Animate(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) @@ -1444,6 +1461,13 @@ class WAN22_Animate(WAN21): out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "face_pixel_values": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1) + if cond_key == "pose_latents": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22_S2V(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) @@ -1480,6 +1504,11 @@ class WAN22_S2V(WAN21): out['reference_motion'] = reference_motion.shape return out + def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]): + if cond_key == "audio_embed": + return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1) + return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list) + class WAN22(WAN21): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) diff --git a/comfy/model_management.py b/comfy/model_management.py index 442d5a40a..2c250dacc 100644 --- a/comfy/model_management.py +++ b/comfy/model_management.py @@ -400,7 +400,7 @@ try: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much - if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 + if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950 ENABLE_PYTORCH_ATTENTION = True if rocm_version >= (7, 0): if any((a in arch) for a in ["gfx1200", "gfx1201"]): @@ -541,6 +541,7 @@ class LoadedModel: if model.parent is not None: self._parent_model = weakref.ref(model.parent) self._patcher_finalizer = weakref.finalize(model, self._switch_parent) + self._patcher_finalizer.atexit = False def _switch_parent(self): model = self._parent_model() @@ -587,6 +588,7 @@ class LoadedModel: self.real_model = weakref.ref(real_model) self.model_finalizer = weakref.finalize(real_model, cleanup_models) + self.model_finalizer.atexit = False return real_model def should_reload_model(self, force_patch_weights=False): diff --git a/comfy/ops.py b/comfy/ops.py index 59c0df87d..1518ec9de 100644 --- a/comfy/ops.py +++ b/comfy/ops.py @@ -336,7 +336,10 @@ class disable_weight_init: class Linear(torch.nn.Linear, CastWeightBiasOp): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + # don't trust subclasses that BYO state dict loader to call us. + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict): super().__init__(in_features, out_features, bias, device, dtype) return @@ -357,7 +360,9 @@ class disable_weight_init: def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict): return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) disable_weight_init._lazy_load_from_state_dict( @@ -564,7 +569,10 @@ class disable_weight_init: def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None, norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None, _freeze=False, device=None, dtype=None): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + # don't trust subclasses that BYO state dict loader to call us. + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict): super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm, norm_type, scale_grad_by_freq, sparse, _weight, _freeze, device, dtype) @@ -590,7 +598,9 @@ class disable_weight_init: def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs): - if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: + if (not comfy.model_management.WINDOWS + or not comfy.memory_management.aimdo_enabled + or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict): return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, missing_keys, unexpected_keys, error_msgs) disable_weight_init._lazy_load_from_state_dict( @@ -766,6 +776,71 @@ from .quant_ops import ( ) +class QuantLinearFunc(torch.autograd.Function): + """Custom autograd function for quantized linear: quantized forward, compute_dtype backward. + Handles any input rank by flattening to 2D for matmul and restoring shape after. + """ + + @staticmethod + def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype): + input_shape = input_float.shape + inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D + + # Quantize input (same as inference path) + if layout_type is not None: + q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale) + else: + q_input = inp + + w = weight.detach() if weight.requires_grad else weight + b = bias.detach() if bias is not None and bias.requires_grad else bias + + output = torch.nn.functional.linear(q_input, w, b) + + # Restore original input shape + if len(input_shape) > 2: + output = output.unflatten(0, input_shape[:-1]) + + ctx.save_for_backward(input_float, weight) + ctx.input_shape = input_shape + ctx.has_bias = bias is not None + ctx.compute_dtype = compute_dtype + ctx.weight_requires_grad = weight.requires_grad + + return output + + @staticmethod + @torch.autograd.function.once_differentiable + def backward(ctx, grad_output): + input_float, weight = ctx.saved_tensors + compute_dtype = ctx.compute_dtype + grad_2d = grad_output.flatten(0, -2).to(compute_dtype) + + # Dequantize weight to compute dtype for backward matmul + if isinstance(weight, QuantizedTensor): + weight_f = weight.dequantize().to(compute_dtype) + else: + weight_f = weight.to(compute_dtype) + + # grad_input = grad_output @ weight + grad_input = torch.mm(grad_2d, weight_f) + if len(ctx.input_shape) > 2: + grad_input = grad_input.unflatten(0, ctx.input_shape[:-1]) + + # grad_weight (only if weight requires grad, typically frozen for quantized training) + grad_weight = None + if ctx.weight_requires_grad: + input_f = input_float.flatten(0, -2).to(compute_dtype) + grad_weight = torch.mm(grad_2d.t(), input_f) + + # grad_bias + grad_bias = None + if ctx.has_bias: + grad_bias = grad_2d.sum(dim=0) + + return grad_input, grad_weight, grad_bias, None, None, None + + def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): class MixedPrecisionOps(manual_cast): _quant_config = quant_config @@ -960,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec #If cast needs to apply lora, it should be done in the compute dtype compute_dtype = input.dtype - if (getattr(self, 'layout_type', None) is not None and + _use_quantized = ( + getattr(self, 'layout_type', None) is not None and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and not getattr(self, 'comfy_force_cast_weights', False) and - len(self.weight_function) == 0 and len(self.bias_function) == 0): + len(self.weight_function) == 0 and len(self.bias_function) == 0 + ) + + # Training path: quantized forward with compute_dtype backward via autograd function + if (input.requires_grad and _use_quantized): + + weight, bias, offload_stream = cast_bias_weight( + self, + input, + offloadable=True, + compute_dtype=compute_dtype, + want_requant=True + ) + + scale = getattr(self, 'input_scale', None) + if scale is not None: + scale = comfy.model_management.cast_to_device(scale, input.device, None) + + output = QuantLinearFunc.apply( + input, weight, bias, self.layout_type, scale, compute_dtype + ) + + uncast_bias_weight(self, weight, bias, offload_stream) + return output + + # Inference path (unchanged) + if _use_quantized: # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input @@ -1011,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec for key, param in self._parameters.items(): if param is None: continue - self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) + p = fn(param) + if p.is_inference(): + p = p.clone() + self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False)) for key, buf in self._buffers.items(): if buf is not None: self._buffers[key] = fn(buf) diff --git a/comfy/sample.py b/comfy/sample.py index a2a39b527..e9c2259ab 100644 --- a/comfy/sample.py +++ b/comfy/sample.py @@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None): samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed) - samples = samples.to(comfy.model_management.intermediate_device()) + samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return samples diff --git a/comfy/sd.py b/comfy/sd.py index 54fbe504c..e2645438c 100644 --- a/comfy/sd.py +++ b/comfy/sd.py @@ -456,7 +456,7 @@ class VAE: self.output_channels = 3 self.pad_channel_value = None self.process_input = lambda image: image * 2.0 - 1.0 - self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) + self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0) self.working_dtypes = [torch.bfloat16, torch.float32] self.disable_offload = False self.not_video = False @@ -952,12 +952,23 @@ class VAE: batch_number = int(free_memory / memory_used) batch_number = max(1, batch_number) + # Pre-allocate output for VAEs that support direct buffer writes + preallocated = False + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype()) + preallocated = True + for x in range(0, samples_in.shape[0], batch_number): - samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) - out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype())) - if pixel_samples is None: - pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) - pixel_samples[x:x+batch_number] = out + samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype) + if preallocated: + self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options) + else: + out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True) + if pixel_samples is None: + pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) + pixel_samples[x:x+batch_number].copy_(out) + del out + self.process_output(pixel_samples[x:x+batch_number]) except Exception as e: model_management.raise_non_oom(e) logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") @@ -968,6 +979,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() dims = samples_in.ndim - 2 if dims == 1 or self.extra_1d_channel is not None: pixel_samples = self.decode_tiled_1d(samples_in) @@ -1028,8 +1040,13 @@ class VAE: batch_number = max(1, batch_number) samples = None for x in range(0, pixel_samples.shape[0], batch_number): - pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) - out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype()) + pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype) + if getattr(self.first_stage_model, 'comfy_has_chunked_io', False): + out = self.first_stage_model.encode(pixels_in, device=self.device) + else: + pixels_in = pixels_in.to(self.device) + out = self.first_stage_model.encode(pixels_in) + out = out.to(self.output_device).to(dtype=self.vae_output_dtype()) if samples is None: samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype()) samples[x:x + batch_number] = out @@ -1044,6 +1061,7 @@ class VAE: do_tile = True if do_tile: + comfy.model_management.soft_empty_cache() if self.latent_dim == 3: tile = 256 overlap = tile // 4 diff --git a/comfy/sd1_clip.py b/comfy/sd1_clip.py index ffb7666ff..897186bba 100644 --- a/comfy/sd1_clip.py +++ b/comfy/sd1_clip.py @@ -46,7 +46,7 @@ class ClipTokenWeightEncoder: out, pooled = o[:2] if pooled is not None: - first_pooled = pooled[0:1].to(model_management.intermediate_device()) + first_pooled = pooled[0:1].to(device=model_management.intermediate_device()) else: first_pooled = pooled @@ -63,16 +63,16 @@ class ClipTokenWeightEncoder: output.append(z) if (len(output) == 0): - r = (out[-1:].to(model_management.intermediate_device()), first_pooled) + r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled) else: - r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) + r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled) if len(o) > 2: extra = {} for k in o[2]: v = o[2][k] if k == "attention_mask": - v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) + v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device()) extra[k] = v r = r + (extra,) diff --git a/comfy/utils.py b/comfy/utils.py index 9931fe3b4..78c491b98 100644 --- a/comfy/utils.py +++ b/comfy/utils.py @@ -897,6 +897,10 @@ def set_attr(obj, attr, value): return prev def set_attr_param(obj, attr, value): + # Clone inference tensors (created under torch.inference_mode) since + # their version counter is frozen and nn.Parameter() cannot wrap them. + if (not torch.is_inference_mode_enabled()) and value.is_inference(): + value = value.clone() return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) def set_attr_buffer(obj, attr, value): @@ -1131,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am pbar.update(1) continue - out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) - out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) + out = output[b:b+1].zero_() + out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device) positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] @@ -1147,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am upscaled.append(round(get_pos(d, pos))) ps = function(s_in).to(output_device) - mask = torch.ones_like(ps) + mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device) for d in range(2, dims + 2): feather = round(get_scale(d - 2, overlap[d - 2])) @@ -1170,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am if pbar is not None: pbar.update(1) - output[b:b+1] = out/out_div + out.div_(out_div) return output def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): diff --git a/comfy_api_nodes/apis/gemini.py b/comfy_api_nodes/apis/gemini.py index 639035fef..22879fe18 100644 --- a/comfy_api_nodes/apis/gemini.py +++ b/comfy_api_nodes/apis/gemini.py @@ -67,6 +67,7 @@ class GeminiPart(BaseModel): inlineData: GeminiInlineData | None = Field(None) fileData: GeminiFileData | None = Field(None) text: str | None = Field(None) + thought: bool | None = Field(None) class GeminiTextPart(BaseModel): diff --git a/comfy_api_nodes/apis/quiver.py b/comfy_api_nodes/apis/quiver.py new file mode 100644 index 000000000..bc8708754 --- /dev/null +++ b/comfy_api_nodes/apis/quiver.py @@ -0,0 +1,43 @@ +from pydantic import BaseModel, Field + + +class QuiverImageObject(BaseModel): + url: str = Field(...) + + +class QuiverTextToSVGRequest(BaseModel): + model: str = Field(default="arrow-preview") + prompt: str = Field(...) + instructions: str | None = Field(default=None) + references: list[QuiverImageObject] | None = Field(default=None, max_length=4) + temperature: float | None = Field(default=None, ge=0, le=2) + top_p: float | None = Field(default=None, ge=0, le=1) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + + +class QuiverImageToSVGRequest(BaseModel): + model: str = Field(default="arrow-preview") + image: QuiverImageObject = Field(...) + auto_crop: bool | None = Field(default=None) + target_size: int | None = Field(default=None, ge=128, le=4096) + temperature: float | None = Field(default=None, ge=0, le=2) + top_p: float | None = Field(default=None, ge=0, le=1) + presence_penalty: float | None = Field(default=None, ge=-2, le=2) + + +class QuiverSVGResponseItem(BaseModel): + svg: str = Field(...) + mime_type: str | None = Field(default="image/svg+xml") + + +class QuiverSVGUsage(BaseModel): + total_tokens: int | None = Field(default=None) + input_tokens: int | None = Field(default=None) + output_tokens: int | None = Field(default=None) + + +class QuiverSVGResponse(BaseModel): + id: str | None = Field(default=None) + created: int | None = Field(default=None) + data: list[QuiverSVGResponseItem] = Field(...) + usage: QuiverSVGUsage | None = Field(default=None) diff --git a/comfy_api_nodes/nodes_bytedance.py b/comfy_api_nodes/nodes_bytedance.py index 6dbd5984e..de0c22e70 100644 --- a/comfy_api_nodes/nodes_bytedance.py +++ b/comfy_api_nodes/nodes_bytedance.py @@ -47,6 +47,10 @@ SEEDREAM_MODELS = { BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} +DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"} + +logger = logging.getLogger(__name__) + def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: if response.error: @@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode): price_badge=IO.PriceBadge( expr="""{"type":"usd","usd":0.03}""", ), + is_deprecated=True, ) @classmethod @@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode): ] return await process_video_task( cls, - payload=Image2VideoTaskCreationRequest(model=model, content=x), + payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), ) @@ -952,6 +957,12 @@ async def process_video_task( payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, estimated_duration: int | None, ) -> IO.NodeOutput: + if payload.model in DEPRECATED_MODELS: + logger.warning( + "Model '%s' is deprecated and will be deactivated on May 13, 2026. " + "Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.", + payload.model, + ) initial_response = await sync_op( cls, ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), diff --git a/comfy_api_nodes/nodes_gemini.py b/comfy_api_nodes/nodes_gemini.py index 8225ea67e..25d747e76 100644 --- a/comfy_api_nodes/nodes_gemini.py +++ b/comfy_api_nodes/nodes_gemini.py @@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge( $m := widgets.model; $r := widgets.resolution; $isFlash := $contains($m, "nano banana 2"); - $flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; + $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154}; $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; $prices := $isFlash ? $flashPrices : $proPrices; {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} @@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str: return "\n".join([part.text for part in parts]) -async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: +async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image: image_tensors: list[Input.Image] = [] parts = get_parts_by_type(response, "image/*") for part in parts: + if (part.thought is True) != thought: + continue if part.inlineData: image_data = base64.b64decode(part.inlineData.data) returned_image = bytesio_to_image_tensor(BytesIO(image_data)) @@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode): outputs=[ IO.Image.Output(), IO.String.Output(), + IO.Image.Output( + display_name="thought_image", + tooltip="First image from the model's thinking process. " + "Only available with thinking_level HIGH and IMAGE+TEXT modality.", + ), ], hidden=[ IO.Hidden.auth_token_comfy_org, @@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode): response_model=GeminiGenerateContentResponse, price_extractor=calculate_tokens_price, ) - return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) + return IO.NodeOutput( + await get_image_from_response(response), + get_text_from_response(response), + await get_image_from_response(response, thought=True), + ) class GeminiExtension(ComfyExtension): diff --git a/comfy_api_nodes/nodes_kling.py b/comfy_api_nodes/nodes_kling.py index 8963c335d..9a37ccc53 100644 --- a/comfy_api_nodes/nodes_kling.py +++ b/comfy_api_nodes/nodes_kling.py @@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode): node_id="KlingOmniProEditVideoNode", display_name="Kling 3.0 Omni Edit Video", category="api node/video/Kling", + essentials_category="Video Generation", description="Edit an existing video with the latest model from Kling.", inputs=[ IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), diff --git a/comfy_api_nodes/nodes_quiver.py b/comfy_api_nodes/nodes_quiver.py new file mode 100644 index 000000000..61533263f --- /dev/null +++ b/comfy_api_nodes/nodes_quiver.py @@ -0,0 +1,291 @@ +from io import BytesIO + +from typing_extensions import override + +from comfy_api.latest import IO, ComfyExtension +from comfy_api_nodes.apis.quiver import ( + QuiverImageObject, + QuiverImageToSVGRequest, + QuiverSVGResponse, + QuiverTextToSVGRequest, +) +from comfy_api_nodes.util import ( + ApiEndpoint, + sync_op, + upload_image_to_comfyapi, + validate_string, +) +from comfy_extras.nodes_images import SVG + + +class QuiverTextToSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="QuiverTextToSVGNode", + display_name="Quiver Text to SVG", + category="api node/image/Quiver", + description="Generate an SVG from a text prompt using Quiver AI.", + inputs=[ + IO.String.Input( + "prompt", + multiline=True, + default="", + tooltip="Text description of the desired SVG output.", + ), + IO.String.Input( + "instructions", + multiline=True, + default="", + tooltip="Additional style or formatting guidance.", + optional=True, + ), + IO.Autogrow.Input( + "reference_images", + template=IO.Autogrow.TemplatePrefix( + IO.Image.Input("image"), + prefix="ref_", + min=0, + max=4, + ), + tooltip="Up to 4 reference images to guide the generation.", + optional=True, + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "arrow-preview", + [ + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ], + ), + ], + tooltip="Model to use for SVG generation.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.429}""", + ), + ) + + @classmethod + async def execute( + cls, + prompt: str, + model: dict, + seed: int, + instructions: str = None, + reference_images: IO.Autogrow.Type = None, + ) -> IO.NodeOutput: + validate_string(prompt, strip_whitespace=False, min_length=1) + + references = None + if reference_images: + references = [] + for key in reference_images: + url = await upload_image_to_comfyapi(cls, reference_images[key]) + references.append(QuiverImageObject(url=url)) + if len(references) > 4: + raise ValueError("Maximum 4 reference images are allowed.") + + instructions_val = instructions.strip() if instructions else None + if instructions_val == "": + instructions_val = None + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"), + response_model=QuiverSVGResponse, + data=QuiverTextToSVGRequest( + model=model["model"], + prompt=prompt, + instructions=instructions_val, + references=references, + temperature=model.get("temperature"), + top_p=model.get("top_p"), + presence_penalty=model.get("presence_penalty"), + ), + ) + + svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data] + return IO.NodeOutput(SVG(svg_data)) + + +class QuiverImageToSVGNode(IO.ComfyNode): + @classmethod + def define_schema(cls): + return IO.Schema( + node_id="QuiverImageToSVGNode", + display_name="Quiver Image to SVG", + category="api node/image/Quiver", + description="Vectorize a raster image into SVG using Quiver AI.", + inputs=[ + IO.Image.Input( + "image", + tooltip="Input image to vectorize.", + ), + IO.Boolean.Input( + "auto_crop", + default=False, + tooltip="Automatically crop to the dominant subject.", + ), + IO.DynamicCombo.Input( + "model", + options=[ + IO.DynamicCombo.Option( + "arrow-preview", + [ + IO.Int.Input( + "target_size", + default=1024, + min=128, + max=4096, + tooltip="Square resize target in pixels.", + ), + IO.Float.Input( + "temperature", + default=1.0, + min=0.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Randomness control. Higher values increase randomness.", + advanced=True, + ), + IO.Float.Input( + "top_p", + default=1.0, + min=0.05, + max=1.0, + step=0.05, + display_mode=IO.NumberDisplay.slider, + tooltip="Nucleus sampling parameter.", + advanced=True, + ), + IO.Float.Input( + "presence_penalty", + default=0.0, + min=-2.0, + max=2.0, + step=0.1, + display_mode=IO.NumberDisplay.slider, + tooltip="Token presence penalty.", + advanced=True, + ), + ], + ), + ], + tooltip="Model to use for SVG vectorization.", + ), + IO.Int.Input( + "seed", + default=0, + min=0, + max=2147483647, + control_after_generate=True, + tooltip="Seed to determine if node should re-run; " + "actual results are nondeterministic regardless of seed.", + ), + ], + outputs=[ + IO.SVG.Output(), + ], + hidden=[ + IO.Hidden.auth_token_comfy_org, + IO.Hidden.api_key_comfy_org, + IO.Hidden.unique_id, + ], + is_api_node=True, + price_badge=IO.PriceBadge( + expr="""{"type":"usd","usd":0.429}""", + ), + ) + + @classmethod + async def execute( + cls, + image, + auto_crop: bool, + model: dict, + seed: int, + ) -> IO.NodeOutput: + image_url = await upload_image_to_comfyapi(cls, image) + + response = await sync_op( + cls, + ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"), + response_model=QuiverSVGResponse, + data=QuiverImageToSVGRequest( + model=model["model"], + image=QuiverImageObject(url=image_url), + auto_crop=auto_crop if auto_crop else None, + target_size=model.get("target_size"), + temperature=model.get("temperature"), + top_p=model.get("top_p"), + presence_penalty=model.get("presence_penalty"), + ), + ) + + svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data] + return IO.NodeOutput(SVG(svg_data)) + + +class QuiverExtension(ComfyExtension): + @override + async def get_node_list(self) -> list[type[IO.ComfyNode]]: + return [ + QuiverTextToSVGNode, + QuiverImageToSVGNode, + ] + + +async def comfy_entrypoint() -> QuiverExtension: + return QuiverExtension() diff --git a/comfy_api_nodes/nodes_recraft.py b/comfy_api_nodes/nodes_recraft.py index 4d1d508fa..c60cfbc4a 100644 --- a/comfy_api_nodes/nodes_recraft.py +++ b/comfy_api_nodes/nodes_recraft.py @@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode): node_id="RecraftVectorizeImageNode", display_name="Recraft Vectorize Image", category="api node/image/Recraft", + essentials_category="Image Tools", description="Generates SVG synchronously from an input image.", inputs=[ IO.Image.Input("image"), diff --git a/comfy_extras/nodes_audio.py b/comfy_extras/nodes_audio.py index 5d8d9bf6f..a395392d8 100644 --- a/comfy_extras/nodes_audio.py +++ b/comfy_extras/nodes_audio.py @@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode): node_id="EmptyLatentAudio", display_name="Empty Latent Audio", category="latent/audio", + essentials_category="Audio", inputs=[ IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), IO.Int.Input( @@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode): search_aliases=["export mp3"], display_name="Save Audio (MP3)", category="audio", + essentials_category="Audio", inputs=[ IO.Audio.Input("audio"), IO.String.Input("filename_prefix", default="audio/ComfyUI"), diff --git a/comfy_extras/nodes_canny.py b/comfy_extras/nodes_canny.py index 5e7c4eabb..648b4279d 100644 --- a/comfy_extras/nodes_canny.py +++ b/comfy_extras/nodes_canny.py @@ -3,6 +3,7 @@ from typing_extensions import override import comfy.model_management from comfy_api.latest import ComfyExtension, io +import torch class Canny(io.ComfyNode): @@ -29,8 +30,8 @@ class Canny(io.ComfyNode): @classmethod def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: - output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) - img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) + output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold) + img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1) return io.NodeOutput(img_out) diff --git a/comfy_extras/nodes_context_windows.py b/comfy_extras/nodes_context_windows.py index 93a5204e1..0e43f2e44 100644 --- a/comfy_extras/nodes_context_windows.py +++ b/comfy_extras/nodes_context_windows.py @@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode): io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), - #io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), - #io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), + io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), + io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), ], outputs=[ io.Model.Output(tooltip="The model with context windows applied during sampling."), diff --git a/comfy_extras/nodes_image_compare.py b/comfy_extras/nodes_image_compare.py index 8e9f809e6..3d943be67 100644 --- a/comfy_extras/nodes_image_compare.py +++ b/comfy_extras/nodes_image_compare.py @@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode): display_name="Image Compare", description="Compares two images side by side with a slider.", category="image", + essentials_category="Image Tools", is_experimental=True, is_output_node=True, inputs=[ diff --git a/comfy_extras/nodes_images.py b/comfy_extras/nodes_images.py index 4c57bb5cb..a8223cf8b 100644 --- a/comfy_extras/nodes_images.py +++ b/comfy_extras/nodes_images.py @@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode): search_aliases=["trim"], display_name="Image Crop", category="image/transform", + essentials_category="Image Tools", inputs=[ IO.Image.Input("image"), IO.BoundingBox.Input("crop_region", component="ImageCrop"), diff --git a/comfy_extras/nodes_post_processing.py b/comfy_extras/nodes_post_processing.py index 4a0f7141a..06626f9dd 100644 --- a/comfy_extras/nodes_post_processing.py +++ b/comfy_extras/nodes_post_processing.py @@ -21,6 +21,7 @@ class Blend(io.ComfyNode): node_id="ImageBlend", display_name="Image Blend", category="image/postprocessing", + essentials_category="Image Tools", inputs=[ io.Image.Input("image1"), io.Image.Input("image2"), diff --git a/comfy_extras/nodes_train.py b/comfy_extras/nodes_train.py index aa2d88673..0ad0acee6 100644 --- a/comfy_extras/nodes_train.py +++ b/comfy_extras/nodes_train.py @@ -15,6 +15,7 @@ import comfy.sampler_helpers import comfy.sd import comfy.utils import comfy.model_management +from comfy.cli_args import args, PerformanceFeature import comfy_extras.nodes_custom_sampler import folder_paths import node_helpers @@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler): training_dtype=torch.bfloat16, real_dataset=None, bucket_latents=None, + use_grad_scaler=False, ): self.loss_fn = loss_fn self.optimizer = optimizer @@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler): self.bucket_latents: list[torch.Tensor] | None = ( bucket_latents # list of (Bi, C, Hi, Wi) ) + # GradScaler for fp16 training + self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None # Precompute bucket offsets and weights for sampling if bucket_latents is not None: self._init_bucket_data(bucket_latents) @@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler): batch_sigmas.requires_grad_(True), **batch_extra_args, ) - loss = self.loss_fn(x0_pred, x0) + loss = self.loss_fn(x0_pred.float(), x0.float()) if bwd: bwd_loss = loss / self.grad_acc - bwd_loss.backward() + if self.grad_scaler is not None: + self.grad_scaler.scale(bwd_loss).backward() + else: + bwd_loss.backward() return loss def _generate_batch_sigmas(self, model_wrap, batch_size, device): @@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler): ) total_loss += loss total_loss = total_loss / self.grad_acc / len(indicies) - total_loss.backward() + if self.grad_scaler is not None: + self.grad_scaler.scale(total_loss).backward() + else: + total_loss.backward() if self.loss_callback: self.loss_callback(total_loss.item()) pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) @@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler): self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) if (i + 1) % self.grad_acc == 0: + if self.grad_scaler is not None: + self.grad_scaler.unscale_(self.optimizer) for param_groups in self.optimizer.param_groups: for param in param_groups["params"]: if param.grad is None: continue param.grad.data = param.grad.data.to(param.data.dtype) - self.optimizer.step() + if self.grad_scaler is not None: + self.grad_scaler.step(self.optimizer) + self.grad_scaler.update() + else: + self.optimizer.step() self.optimizer.zero_grad() ui_pbar.update(1) torch.cuda.empty_cache() @@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode): ), io.Combo.Input( "training_dtype", - options=["bf16", "fp32"], + options=["bf16", "fp32", "none"], default="bf16", - tooltip="The dtype to use for training.", + tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.", ), io.Combo.Input( "lora_dtype", @@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode): io.Boolean.Input( "offloading", default=False, - tooltip="Offload the Model to RAM. Requires Bypass Mode.", + tooltip="Offload model weights to CPU during training to save GPU memory.", ), io.Combo.Input( "existing_lora", @@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode): # Setup model and dtype mp = model.clone() - dtype = node_helpers.string_to_torch_dtype(training_dtype) + use_grad_scaler = False + if training_dtype != "none": + dtype = node_helpers.string_to_torch_dtype(training_dtype) + mp.set_model_compute_dtype(dtype) + else: + # Detect model's native dtype for autocast + model_dtype = mp.model.get_dtype() + if model_dtype == torch.float16: + dtype = torch.float16 + use_grad_scaler = True + # Warn about fp16 accumulation instability during training + if PerformanceFeature.Fp16Accumulation in args.fast: + logging.warning( + "WARNING: FP16 model detected with fp16_accumulation enabled. " + "This combination can be numerically unstable during training and may cause NaN values. " + "Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)." + ) + else: + # For fp8, bf16, or other dtypes, use bf16 autocast + dtype = torch.bfloat16 lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) - mp.set_model_compute_dtype(dtype) - - if mp.is_dynamic(): - if not bypass_mode: - logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode") - bypass_mode = True - offloading = True - elif offloading: - if not bypass_mode: - logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message") # Prepare latents and compute counts + latents_dtype = dtype if dtype not in (None,) else torch.bfloat16 latents, num_images, multi_res = _prepare_latents_and_count( - latents, dtype, bucket_mode + latents, latents_dtype, bucket_mode ) # Validate and expand conditioning @@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode): seed=seed, training_dtype=dtype, bucket_latents=latents, + use_grad_scaler=use_grad_scaler, ) else: train_sampler = TrainSampler( @@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode): seed=seed, training_dtype=dtype, real_dataset=latents if multi_res else None, + use_grad_scaler=use_grad_scaler, ) # Setup guider @@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode): io.Int.Input( "steps", optional=True, - tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", + tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.", ), ], outputs=[], diff --git a/comfyui_version.py b/comfyui_version.py index 701f4d66a..a3b7204dc 100644 --- a/comfyui_version.py +++ b/comfyui_version.py @@ -1,3 +1,3 @@ # This file is automatically generated by the build process when version is # updated in pyproject.toml. -__version__ = "0.17.0" +__version__ = "0.18.0" diff --git a/main.py b/main.py index 8905fd09a..f99aee38e 100644 --- a/main.py +++ b/main.py @@ -206,8 +206,8 @@ import hook_breaker_ac10a0 import comfy.memory_management import comfy.model_patcher -if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): - if comfy.model_management.torch_version_numeric < (2, 8): +if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()): + if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)): logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): if args.verbose == 'DEBUG': diff --git a/manager_requirements.txt b/manager_requirements.txt index 37a33bd4f..5b06b56f6 100644 --- a/manager_requirements.txt +++ b/manager_requirements.txt @@ -1 +1 @@ -comfyui_manager==4.1b4 \ No newline at end of file +comfyui_manager==4.1b6 \ No newline at end of file diff --git a/nodes.py b/nodes.py index 1e19a8223..2c4650a20 100644 --- a/nodes.py +++ b/nodes.py @@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC): class ConditioningCombine: + ESSENTIALS_CATEGORY = "Image Generation" @classmethod def INPUT_TYPES(s): return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}} @@ -951,7 +952,7 @@ class UNETLoader: @classmethod def INPUT_TYPES(s): return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), - "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],) + "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True}) }} RETURN_TYPES = ("MODEL",) FUNCTION = "load_unet" @@ -1211,9 +1212,6 @@ class GLIGENTextBoxApply: return (c, ) class EmptyLatentImage: - def __init__(self): - self.device = comfy.model_management.intermediate_device() - @classmethod def INPUT_TYPES(s): return { @@ -1232,7 +1230,7 @@ class EmptyLatentImage: SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] def generate(self, width, height, batch_size=1): - latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) + latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()) return ({"samples": latent, "downscale_ratio_spacial": 8}, ) @@ -1781,6 +1779,7 @@ class LoadImage: return True class LoadImageMask: + ESSENTIALS_CATEGORY = "Image Tools" SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] _color_channels = ["alpha", "red", "green", "blue"] @@ -1889,6 +1888,7 @@ class ImageScale: return (s,) class ImageScaleBy: + ESSENTIALS_CATEGORY = "Image Tools" upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] @classmethod @@ -1966,9 +1966,11 @@ class EmptyImage: CATEGORY = "image" def generate(self, width, height, batch_size=1, color=0): - r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF) - g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF) - b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF) + dtype = comfy.model_management.intermediate_dtype() + device = comfy.model_management.intermediate_device() + r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype) + g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype) + b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype) return (torch.cat((r, g, b), dim=-1), ) class ImagePadForOutpaint: diff --git a/pyproject.toml b/pyproject.toml index e2ca79be7..6db9b1267 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "ComfyUI" -version = "0.17.0" +version = "0.18.0" readme = "README.md" license = { file = "LICENSE" } requires-python = ">=3.10" diff --git a/requirements.txt b/requirements.txt index 7e59ef206..ad0344ed4 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,5 +1,5 @@ -comfyui-frontend-package==1.41.20 -comfyui-workflow-templates==0.9.21 +comfyui-frontend-package==1.41.21 +comfyui-workflow-templates==0.9.26 comfyui-embedded-docs==0.4.3 torch torchsde diff --git a/server.py b/server.py index 85a8964be..173a28376 100644 --- a/server.py +++ b/server.py @@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version from comfy_api.internal import _ComfyNodeInternal from app.assets.seeder import asset_seeder from app.assets.api.routes import register_assets_routes +from app.assets.services.ingest import register_file_in_place +from app.assets.services.asset_management import resolve_hash_to_path from app.user_manager import UserManager from app.model_manager import ModelFileManager @@ -419,7 +421,24 @@ class PromptServer(): with open(filepath, "wb") as f: f.write(image.file.read()) - return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) + resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type} + + if args.enable_assets: + try: + tag = image_upload_type if image_upload_type in ("input", "output") else "input" + result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag]) + resp["asset"] = { + "id": result.ref.id, + "name": result.ref.name, + "asset_hash": result.asset.hash, + "size": result.asset.size_bytes, + "mime_type": result.asset.mime_type, + "tags": result.tags, + } + except Exception: + logging.warning("Failed to register uploaded image as asset", exc_info=True) + + return web.json_response(resp) else: return web.Response(status=400) @@ -479,30 +498,43 @@ class PromptServer(): async def view_image(request): if "filename" in request.rel_url.query: filename = request.rel_url.query["filename"] - filename, output_dir = folder_paths.annotated_filepath(filename) - if not filename: - return web.Response(status=400) + # The frontend's LoadImage combo widget uses asset_hash values + # (e.g. "blake3:...") as widget values. When litegraph renders the + # node preview, it constructs /view?filename=, so this + # endpoint must resolve blake3 hashes to their on-disk file paths. + if filename.startswith("blake3:"): + owner_id = self.user_manager.get_request_user_id(request) + result = resolve_hash_to_path(filename, owner_id=owner_id) + if result is None: + return web.Response(status=404) + file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type + else: + resolved_content_type = None + filename, output_dir = folder_paths.annotated_filepath(filename) - # validation for security: prevent accessing arbitrary path - if filename[0] == '/' or '..' in filename: - return web.Response(status=400) + if not filename: + return web.Response(status=400) - if output_dir is None: - type = request.rel_url.query.get("type", "output") - output_dir = folder_paths.get_directory_by_type(type) + # validation for security: prevent accessing arbitrary path + if filename[0] == '/' or '..' in filename: + return web.Response(status=400) - if output_dir is None: - return web.Response(status=400) + if output_dir is None: + type = request.rel_url.query.get("type", "output") + output_dir = folder_paths.get_directory_by_type(type) - if "subfolder" in request.rel_url.query: - full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) - if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: - return web.Response(status=403) - output_dir = full_output_dir + if output_dir is None: + return web.Response(status=400) - filename = os.path.basename(filename) - file = os.path.join(output_dir, filename) + if "subfolder" in request.rel_url.query: + full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) + if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir: + return web.Response(status=403) + output_dir = full_output_dir + + filename = os.path.basename(filename) + file = os.path.join(output_dir, filename) if os.path.isfile(file): if 'preview' in request.rel_url.query: @@ -562,8 +594,13 @@ class PromptServer(): return web.Response(body=alpha_buffer.read(), content_type='image/png', headers={"Content-Disposition": f"filename=\"{filename}\""}) else: - # Get content type from mimetype, defaulting to 'application/octet-stream' - content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' + # Use the content type from asset resolution if available, + # otherwise guess from the filename. + content_type = ( + resolved_content_type + or mimetypes.guess_type(filename)[0] + or 'application/octet-stream' + ) # For security, force certain mimetypes to download instead of display if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: diff --git a/tests-unit/app_test/test_migrations.py b/tests-unit/app_test/test_migrations.py new file mode 100644 index 000000000..fa10c1727 --- /dev/null +++ b/tests-unit/app_test/test_migrations.py @@ -0,0 +1,57 @@ +"""Test that Alembic migrations run cleanly on a file-backed SQLite DB. + +This catches problems like unnamed FK constraints that prevent batch-mode +drop_constraint from working on real SQLite files (see MB-2). + +Migrations 0001 and 0002 are already shipped, so we only exercise +upgrade/downgrade for 0003+. +""" + +import os + +import pytest +from alembic import command +from alembic.config import Config + + +# Oldest shipped revision — we upgrade to here as a baseline and never +# downgrade past it. +_BASELINE = "0002_merge_to_asset_references" + + +def _make_config(db_path: str) -> Config: + root = os.path.join(os.path.dirname(__file__), "../..") + config_path = os.path.abspath(os.path.join(root, "alembic.ini")) + scripts_path = os.path.abspath(os.path.join(root, "alembic_db")) + + cfg = Config(config_path) + cfg.set_main_option("script_location", scripts_path) + cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}") + return cfg + + +@pytest.fixture +def migration_db(tmp_path): + """Yield an alembic Config pre-upgraded to the baseline revision.""" + db_path = str(tmp_path / "test_migration.db") + cfg = _make_config(db_path) + command.upgrade(cfg, _BASELINE) + yield cfg + + +def test_upgrade_to_head(migration_db): + """Upgrade from baseline to head must succeed on a file-backed DB.""" + command.upgrade(migration_db, "head") + + +def test_downgrade_to_baseline(migration_db): + """Upgrade to head then downgrade back to baseline.""" + command.upgrade(migration_db, "head") + command.downgrade(migration_db, _BASELINE) + + +def test_upgrade_downgrade_cycle(migration_db): + """Full cycle: upgrade → downgrade → upgrade again.""" + command.upgrade(migration_db, "head") + command.downgrade(migration_db, _BASELINE) + command.upgrade(migration_db, "head") diff --git a/tests-unit/assets_test/queries/test_asset.py b/tests-unit/assets_test/queries/test_asset.py index 08f84cd11..9b7eb4bac 100644 --- a/tests-unit/assets_test/queries/test_asset.py +++ b/tests-unit/assets_test/queries/test_asset.py @@ -10,6 +10,7 @@ from app.assets.database.queries import ( get_asset_by_hash, upsert_asset, bulk_insert_assets, + update_asset_hash_and_mime, ) @@ -142,3 +143,45 @@ class TestBulkInsertAssets: session.commit() assert session.query(Asset).count() == 200 + + +class TestMimeTypeImmutability: + """mime_type on Asset is write-once: set on first ingest, never overwritten.""" + + @pytest.mark.parametrize( + "initial_mime,second_mime,expected_mime", + [ + ("image/png", "image/jpeg", "image/png"), + (None, "image/png", "image/png"), + ], + ids=["preserves_existing", "fills_null"], + ) + def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime): + h = f"blake3:upsert_{initial_mime}_{second_mime}" + upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime) + session.commit() + + asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime) + assert created is False + assert asset.mime_type == expected_mime + + @pytest.mark.parametrize( + "initial_mime,update_mime,update_hash,expected_mime,expected_hash", + [ + (None, "image/png", None, "image/png", "blake3:upd0"), + ("image/png", "image/jpeg", None, "image/png", "blake3:upd1"), + ("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"), + ], + ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"], + ) + def test_update_asset_hash_and_mime_immutability( + self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash, + ): + h = expected_hash.removesuffix("_new") + asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime) + session.add(asset) + session.flush() + + update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash) + assert asset.mime_type == expected_mime + assert asset.hash == expected_hash diff --git a/tests-unit/assets_test/queries/test_asset_info.py b/tests-unit/assets_test/queries/test_asset_info.py index 8f6c7fcdb..fe510e342 100644 --- a/tests-unit/assets_test/queries/test_asset_info.py +++ b/tests-unit/assets_test/queries/test_asset_info.py @@ -242,22 +242,24 @@ class TestSetReferencePreview: asset = _make_asset(session, "hash1") preview_asset = _make_asset(session, "preview_hash") ref = _make_reference(session, asset) + preview_ref = _make_reference(session, preview_asset, name="preview.png") session.commit() - set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) + set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id) session.commit() session.refresh(ref) - assert ref.preview_id == preview_asset.id + assert ref.preview_id == preview_ref.id def test_clears_preview(self, session: Session): asset = _make_asset(session, "hash1") preview_asset = _make_asset(session, "preview_hash") ref = _make_reference(session, asset) - ref.preview_id = preview_asset.id + preview_ref = _make_reference(session, preview_asset, name="preview.png") + ref.preview_id = preview_ref.id session.commit() - set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) + set_reference_preview(session, reference_id=ref.id, preview_reference_id=None) session.commit() session.refresh(ref) @@ -265,15 +267,15 @@ class TestSetReferencePreview: def test_raises_for_nonexistent_reference(self, session: Session): with pytest.raises(ValueError, match="not found"): - set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) + set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None) def test_raises_for_nonexistent_preview(self, session: Session): asset = _make_asset(session, "hash1") ref = _make_reference(session, asset) session.commit() - with pytest.raises(ValueError, match="Preview Asset"): - set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") + with pytest.raises(ValueError, match="Preview AssetReference"): + set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent") class TestInsertReference: @@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps: asset = _make_asset(session, "hash1") preview_asset = _make_asset(session, "preview_hash") ref = _make_reference(session, asset) + preview_ref = _make_reference(session, preview_asset, name="preview.png") session.commit() - update_reference_timestamps(session, ref, preview_id=preview_asset.id) + update_reference_timestamps(session, ref, preview_id=preview_ref.id) session.commit() session.refresh(ref) - assert ref.preview_id == preview_asset.id + assert ref.preview_id == preview_ref.id class TestSetReferenceMetadata: diff --git a/tests-unit/assets_test/queries/test_metadata.py b/tests-unit/assets_test/queries/test_metadata.py index 6a545e819..d7a747789 100644 --- a/tests-unit/assets_test/queries/test_metadata.py +++ b/tests-unit/assets_test/queries/test_metadata.py @@ -20,6 +20,7 @@ def _make_reference( asset: Asset, name: str, metadata: dict | None = None, + system_metadata: dict | None = None, ) -> AssetReference: now = get_utc_now() ref = AssetReference( @@ -27,6 +28,7 @@ def _make_reference( name=name, asset_id=asset.id, user_metadata=metadata, + system_metadata=system_metadata, created_at=now, updated_at=now, last_access_time=now, @@ -34,8 +36,10 @@ def _make_reference( session.add(ref) session.flush() - if metadata: - for key, val in metadata.items(): + # Build merged projection: {**system_metadata, **user_metadata} + merged = {**(system_metadata or {}), **(metadata or {})} + if merged: + for key, val in merged.items(): for row in convert_metadata_to_rows(key, val): meta_row = AssetReferenceMeta( asset_reference_id=ref.id, @@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict: refs, _, total = list_references_page(session, metadata_filter={}) assert total == 2 + + +class TestSystemMetadataProjection: + """Tests for system_metadata merging into the filter projection.""" + + def test_system_metadata_keys_are_filterable(self, session: Session): + """system_metadata keys should appear in the merged projection.""" + asset = _make_asset(session, "hash1") + _make_reference( + session, asset, "with_sys", + system_metadata={"source": "scanner"}, + ) + _make_reference(session, asset, "without_sys") + session.commit() + + refs, _, total = list_references_page( + session, metadata_filter={"source": "scanner"} + ) + assert total == 1 + assert refs[0].name == "with_sys" + + def test_user_metadata_overrides_system_metadata(self, session: Session): + """user_metadata should win when both have the same key.""" + asset = _make_asset(session, "hash1") + _make_reference( + session, asset, "overridden", + metadata={"origin": "user_upload"}, + system_metadata={"origin": "auto_scan"}, + ) + session.commit() + + # Should match the user value, not the system value + refs, _, total = list_references_page( + session, metadata_filter={"origin": "user_upload"} + ) + assert total == 1 + assert refs[0].name == "overridden" + + # Should NOT match the system value (it was overridden) + refs, _, total = list_references_page( + session, metadata_filter={"origin": "auto_scan"} + ) + assert total == 0 diff --git a/tests-unit/assets_test/services/test_asset_management.py b/tests-unit/assets_test/services/test_asset_management.py index 101ef7292..e8ff989e9 100644 --- a/tests-unit/assets_test/services/test_asset_management.py +++ b/tests-unit/assets_test/services/test_asset_management.py @@ -11,6 +11,7 @@ from app.assets.services import ( delete_asset_reference, set_asset_preview, ) +from app.assets.services.asset_management import resolve_hash_to_path def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: @@ -219,31 +220,33 @@ class TestSetAssetPreview: asset = _make_asset(session, hash_val="blake3:main") preview_asset = _make_asset(session, hash_val="blake3:preview") ref = _make_reference(session, asset) + preview_ref = _make_reference(session, preview_asset, name="preview.png") ref_id = ref.id - preview_id = preview_asset.id + preview_ref_id = preview_ref.id session.commit() set_asset_preview( reference_id=ref_id, - preview_asset_id=preview_id, + preview_reference_id=preview_ref_id, ) # Verify by re-fetching from DB session.expire_all() updated_ref = session.get(AssetReference, ref_id) - assert updated_ref.preview_id == preview_id + assert updated_ref.preview_id == preview_ref_id def test_clears_preview(self, mock_create_session, session: Session): asset = _make_asset(session) preview_asset = _make_asset(session, hash_val="blake3:preview") ref = _make_reference(session, asset) - ref.preview_id = preview_asset.id + preview_ref = _make_reference(session, preview_asset, name="preview.png") + ref.preview_id = preview_ref.id ref_id = ref.id session.commit() set_asset_preview( reference_id=ref_id, - preview_asset_id=None, + preview_reference_id=None, ) # Verify by re-fetching from DB @@ -263,6 +266,45 @@ class TestSetAssetPreview: with pytest.raises(PermissionError, match="not owner"): set_asset_preview( reference_id=ref.id, - preview_asset_id=None, + preview_reference_id=None, owner_id="user2", ) + + +class TestResolveHashToPath: + def test_returns_none_for_unknown_hash(self, mock_create_session): + result = resolve_hash_to_path("blake3:" + "a" * 64) + assert result is None + + @pytest.mark.parametrize( + "ref_owner, query_owner, expect_found", + [ + ("user1", "user1", True), + ("user1", "user2", False), + ("", "anyone", True), + ("", "", True), + ], + ids=[ + "owner_sees_own_ref", + "other_owner_blocked", + "ownerless_visible_to_anyone", + "ownerless_visible_to_empty", + ], + ) + def test_owner_visibility( + self, ref_owner, query_owner, expect_found, + mock_create_session, session: Session, temp_dir, + ): + f = temp_dir / "file.bin" + f.write_bytes(b"data") + asset = _make_asset(session, hash_val="blake3:" + "b" * 64) + ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner) + ref.file_path = str(f) + session.commit() + + result = resolve_hash_to_path(asset.hash, owner_id=query_owner) + if expect_found: + assert result is not None + assert result.abs_path == str(f) + else: + assert result is None diff --git a/tests-unit/assets_test/services/test_ingest.py b/tests-unit/assets_test/services/test_ingest.py index 367bc7721..dbb8441c2 100644 --- a/tests-unit/assets_test/services/test_ingest.py +++ b/tests-unit/assets_test/services/test_ingest.py @@ -113,11 +113,19 @@ class TestIngestFileFromPath: file_path = temp_dir / "with_preview.bin" file_path.write_bytes(b"data") - # Create a preview asset first + # Create a preview asset and reference preview_asset = Asset(hash="blake3:preview", size_bytes=100) session.add(preview_asset) + session.flush() + from app.assets.helpers import get_utc_now + now = get_utc_now() + preview_ref = AssetReference( + asset_id=preview_asset.id, name="preview.png", owner_id="", + created_at=now, updated_at=now, last_access_time=now, + ) + session.add(preview_ref) session.commit() - preview_id = preview_asset.id + preview_id = preview_ref.id result = _ingest_file_from_path( abs_path=str(file_path), diff --git a/tests-unit/assets_test/services/test_tag_histogram.py b/tests-unit/assets_test/services/test_tag_histogram.py new file mode 100644 index 000000000..7bcd518ec --- /dev/null +++ b/tests-unit/assets_test/services/test_tag_histogram.py @@ -0,0 +1,123 @@ +"""Tests for list_tag_histogram service function.""" +from sqlalchemy.orm import Session + +from app.assets.database.models import Asset, AssetReference +from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference +from app.assets.helpers import get_utc_now +from app.assets.services.tagging import list_tag_histogram + + +def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset: + asset = Asset(hash=hash_val, size_bytes=1024) + session.add(asset) + session.flush() + return asset + + +def _make_reference( + session: Session, + asset: Asset, + name: str = "test", + owner_id: str = "", +) -> AssetReference: + now = get_utc_now() + ref = AssetReference( + owner_id=owner_id, + name=name, + asset_id=asset.id, + created_at=now, + updated_at=now, + last_access_time=now, + ) + session.add(ref) + session.flush() + return ref + + +class TestListTagHistogram: + def test_returns_counts_for_all_tags(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["alpha", "beta"]) + a1 = _make_asset(session, "blake3:aaa") + r1 = _make_reference(session, a1, name="r1") + add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"]) + + a2 = _make_asset(session, "blake3:bbb") + r2 = _make_reference(session, a2, name="r2") + add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"]) + session.commit() + + result = list_tag_histogram() + + assert result["alpha"] == 2 + assert result["beta"] == 1 + + def test_empty_when_no_assets(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["unused"]) + session.commit() + + result = list_tag_histogram() + + assert result == {} + + def test_include_tags_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["models", "loras", "input"]) + a1 = _make_asset(session, "blake3:aaa") + r1 = _make_reference(session, a1, name="r1") + add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"]) + + a2 = _make_asset(session, "blake3:bbb") + r2 = _make_reference(session, a2, name="r2") + add_tags_to_reference(session, reference_id=r2.id, tags=["input"]) + session.commit() + + result = list_tag_histogram(include_tags=["models"]) + + # Only r1 has "models", so only its tags appear + assert "models" in result + assert "loras" in result + assert "input" not in result + + def test_exclude_tags_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["models", "loras", "input"]) + a1 = _make_asset(session, "blake3:aaa") + r1 = _make_reference(session, a1, name="r1") + add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"]) + + a2 = _make_asset(session, "blake3:bbb") + r2 = _make_reference(session, a2, name="r2") + add_tags_to_reference(session, reference_id=r2.id, tags=["input"]) + session.commit() + + result = list_tag_histogram(exclude_tags=["models"]) + + # r1 excluded, only r2's tags remain + assert "input" in result + assert "loras" not in result + + def test_name_contains_filter(self, mock_create_session, session: Session): + ensure_tags_exist(session, ["alpha", "beta"]) + a1 = _make_asset(session, "blake3:aaa") + r1 = _make_reference(session, a1, name="my_model.safetensors") + add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"]) + + a2 = _make_asset(session, "blake3:bbb") + r2 = _make_reference(session, a2, name="picture.png") + add_tags_to_reference(session, reference_id=r2.id, tags=["beta"]) + session.commit() + + result = list_tag_histogram(name_contains="model") + + assert "alpha" in result + assert "beta" not in result + + def test_limit_caps_results(self, mock_create_session, session: Session): + tags = [f"tag{i}" for i in range(10)] + ensure_tags_exist(session, tags) + a = _make_asset(session, "blake3:aaa") + r = _make_reference(session, a, name="r1") + add_tags_to_reference(session, reference_id=r.id, tags=tags) + session.commit() + + result = list_tag_histogram(limit=3) + + assert len(result) == 3 diff --git a/tests-unit/assets_test/test_uploads.py b/tests-unit/assets_test/test_uploads.py index d68e5b5d7..0f2b124a3 100644 --- a/tests-unit/assets_test/test_uploads.py +++ b/tests-unit/assets_test/test_uploads.py @@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str): assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") +def test_upload_empty_tags_rejected(http: requests.Session, api_base: str): + files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")} + form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})} + r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120) + body = r.json() + assert r.status_code == 400 + assert body["error"]["code"] == "INVALID_BODY" + + @pytest.mark.parametrize("root", ["input", "output"]) def test_duplicate_upload_same_display_name_does_not_clobber( root: str,