mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-04-15 04:52:31 +08:00
Merge branch 'master' into deepme987/auto-register-node-replacements-json
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
Some checks are pending
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.10) (push) Waiting to run
Build package / Build Test (3.11) (push) Waiting to run
Build package / Build Test (3.12) (push) Waiting to run
Build package / Build Test (3.13) (push) Waiting to run
Build package / Build Test (3.14) (push) Waiting to run
This commit is contained in:
commit
b9b24d425b
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
@ -0,0 +1,103 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Checks pull request commits for AI agent Co-authored-by trailers.
|
||||||
|
# Exits non-zero when any are found and prints fix instructions.
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||||
|
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||||
|
|
||||||
|
# Known AI coding-agent trailer patterns (case-insensitive).
|
||||||
|
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
|
||||||
|
AGENT_PATTERNS=(
|
||||||
|
# Anthropic — Claude Code / Amp
|
||||||
|
'noreply@anthropic\.com'
|
||||||
|
# Cursor
|
||||||
|
'cursoragent@cursor\.com'
|
||||||
|
# GitHub Copilot
|
||||||
|
'copilot-swe-agent\[bot\]'
|
||||||
|
'copilot@github\.com'
|
||||||
|
# OpenAI Codex
|
||||||
|
'noreply@openai\.com'
|
||||||
|
'codex@openai\.com'
|
||||||
|
# Aider
|
||||||
|
'aider@aider\.chat'
|
||||||
|
# Google — Gemini / Jules
|
||||||
|
'gemini@google\.com'
|
||||||
|
'jules@google\.com'
|
||||||
|
# Windsurf / Codeium
|
||||||
|
'@codeium\.com'
|
||||||
|
# Devin
|
||||||
|
'devin-ai-integration\[bot\]'
|
||||||
|
'devin@cognition\.ai'
|
||||||
|
'devin@cognition-labs\.com'
|
||||||
|
# Amazon Q Developer
|
||||||
|
'amazon-q-developer'
|
||||||
|
'@amazon\.com.*[Qq].[Dd]eveloper'
|
||||||
|
# Cline
|
||||||
|
'cline-bot'
|
||||||
|
'cline@cline\.ai'
|
||||||
|
# Continue
|
||||||
|
'continue-agent'
|
||||||
|
'continue@continue\.dev'
|
||||||
|
# Sourcegraph
|
||||||
|
'noreply@sourcegraph\.com'
|
||||||
|
# Generic catch-alls for common agent name patterns
|
||||||
|
'Co-authored-by:.*\b[Cc]laude\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]opilot\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]ursor\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]odex\b'
|
||||||
|
'Co-authored-by:.*\b[Gg]emini\b'
|
||||||
|
'Co-authored-by:.*\b[Aa]ider\b'
|
||||||
|
'Co-authored-by:.*\b[Dd]evin\b'
|
||||||
|
'Co-authored-by:.*\b[Ww]indsurf\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]line\b'
|
||||||
|
'Co-authored-by:.*\b[Aa]mazon Q\b'
|
||||||
|
'Co-authored-by:.*\b[Jj]ules\b'
|
||||||
|
'Co-authored-by:.*\bOpenCode\b'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build a single alternation regex from all patterns.
|
||||||
|
regex=""
|
||||||
|
for pattern in "${AGENT_PATTERNS[@]}"; do
|
||||||
|
if [[ -n "$regex" ]]; then
|
||||||
|
regex="${regex}|${pattern}"
|
||||||
|
else
|
||||||
|
regex="$pattern"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Collect Co-authored-by lines from every commit in the PR range.
|
||||||
|
violations=""
|
||||||
|
while IFS= read -r sha; do
|
||||||
|
message="$(git log -1 --format='%B' "$sha")"
|
||||||
|
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
|
||||||
|
if [[ -z "$matched_lines" ]]; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
while IFS= read -r line; do
|
||||||
|
if echo "$line" | grep -iqE "$regex"; then
|
||||||
|
short="$(git log -1 --format='%h' "$sha")"
|
||||||
|
violations="${violations} ${short}: ${line}"$'\n'
|
||||||
|
fi
|
||||||
|
done <<< "$matched_lines"
|
||||||
|
done < <(git rev-list "${base_sha}..${head_sha}")
|
||||||
|
|
||||||
|
if [[ -n "$violations" ]]; then
|
||||||
|
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
|
||||||
|
echo ""
|
||||||
|
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
|
||||||
|
echo ""
|
||||||
|
echo "$violations"
|
||||||
|
echo "These trailers should be removed before merging."
|
||||||
|
echo ""
|
||||||
|
echo "To fix, rewrite the commit messages with:"
|
||||||
|
echo " git rebase -i ${base_sha}"
|
||||||
|
echo ""
|
||||||
|
echo "and remove the Co-authored-by lines, then force-push your branch."
|
||||||
|
echo ""
|
||||||
|
echo "If you believe this is a false positive, please open an issue."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "No AI agent Co-authored-by trailers found."
|
||||||
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
name: Check AI Co-Authors
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches: ['*']
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-ai-co-authors:
|
||||||
|
name: Check for AI agent co-author trailers
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Check commits for AI co-author trailers
|
||||||
|
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
|
||||||
@ -8,7 +8,7 @@ from alembic import context
|
|||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
|
|
||||||
from app.database.models import Base
|
from app.database.models import Base, NAMING_CONVENTION
|
||||||
target_metadata = Base.metadata
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
# other values from the config, defined by the needs of env.py,
|
||||||
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
|
|||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(
|
||||||
connection=connection, target_metadata=target_metadata
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
render_as_batch=True,
|
||||||
|
naming_convention=NAMING_CONVENTION,
|
||||||
)
|
)
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
|
|||||||
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Add system_metadata and job_id columns to asset_references.
|
||||||
|
Change preview_id FK from assets.id to asset_references.id.
|
||||||
|
|
||||||
|
Revision ID: 0003_add_metadata_job_id
|
||||||
|
Revises: 0002_merge_to_asset_references
|
||||||
|
Create Date: 2026-03-09
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from app.database.models import NAMING_CONVENTION
|
||||||
|
|
||||||
|
revision = "0003_add_metadata_job_id"
|
||||||
|
down_revision = "0002_merge_to_asset_references"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
with op.batch_alter_table("asset_references") as batch_op:
|
||||||
|
batch_op.add_column(
|
||||||
|
sa.Column("system_metadata", sa.JSON(), nullable=True)
|
||||||
|
)
|
||||||
|
batch_op.add_column(
|
||||||
|
sa.Column("job_id", sa.String(length=36), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Change preview_id FK from assets.id to asset_references.id (self-ref).
|
||||||
|
# Existing values are asset-content IDs that won't match reference IDs,
|
||||||
|
# so null them out first.
|
||||||
|
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
|
||||||
|
with op.batch_alter_table(
|
||||||
|
"asset_references", naming_convention=NAMING_CONVENTION
|
||||||
|
) as batch_op:
|
||||||
|
batch_op.drop_constraint(
|
||||||
|
"fk_asset_references_preview_id_assets", type_="foreignkey"
|
||||||
|
)
|
||||||
|
batch_op.create_foreign_key(
|
||||||
|
"fk_asset_references_preview_id_asset_references",
|
||||||
|
"asset_references",
|
||||||
|
["preview_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
batch_op.create_index(
|
||||||
|
"ix_asset_references_preview_id", ["preview_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Purge any all-null meta rows before adding the constraint
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM asset_reference_meta"
|
||||||
|
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
|
||||||
|
)
|
||||||
|
with op.batch_alter_table("asset_reference_meta") as batch_op:
|
||||||
|
batch_op.create_check_constraint(
|
||||||
|
"ck_asset_reference_meta_has_value",
|
||||||
|
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# SQLite doesn't reflect CHECK constraints, so we must declare it
|
||||||
|
# explicitly via table_args for the batch recreate to find it.
|
||||||
|
# Use the fully-rendered constraint name to avoid the naming convention
|
||||||
|
# doubling the prefix.
|
||||||
|
with op.batch_alter_table(
|
||||||
|
"asset_reference_meta",
|
||||||
|
table_args=[
|
||||||
|
sa.CheckConstraint(
|
||||||
|
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||||
|
name="ck_asset_reference_meta_has_value",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) as batch_op:
|
||||||
|
batch_op.drop_constraint(
|
||||||
|
"ck_asset_reference_meta_has_value", type_="check"
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table(
|
||||||
|
"asset_references", naming_convention=NAMING_CONVENTION
|
||||||
|
) as batch_op:
|
||||||
|
batch_op.drop_index("ix_asset_references_preview_id")
|
||||||
|
batch_op.drop_constraint(
|
||||||
|
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
|
||||||
|
)
|
||||||
|
batch_op.create_foreign_key(
|
||||||
|
"fk_asset_references_preview_id_assets",
|
||||||
|
"assets",
|
||||||
|
["preview_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("asset_references") as batch_op:
|
||||||
|
batch_op.drop_column("job_id")
|
||||||
|
batch_op.drop_column("system_metadata")
|
||||||
@ -13,6 +13,7 @@ from pydantic import ValidationError
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from app import user_manager
|
from app import user_manager
|
||||||
from app.assets.api import schemas_in, schemas_out
|
from app.assets.api import schemas_in, schemas_out
|
||||||
|
from app.assets.services import schemas
|
||||||
from app.assets.api.schemas_in import (
|
from app.assets.api.schemas_in import (
|
||||||
AssetValidationError,
|
AssetValidationError,
|
||||||
UploadError,
|
UploadError,
|
||||||
@ -38,6 +39,7 @@ from app.assets.services import (
|
|||||||
update_asset_metadata,
|
update_asset_metadata,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
|
from app.assets.services.tagging import list_tag_histogram
|
||||||
|
|
||||||
ROUTES = web.RouteTableDef()
|
ROUTES = web.RouteTableDef()
|
||||||
USER_MANAGER: user_manager.UserManager | None = None
|
USER_MANAGER: user_manager.UserManager | None = None
|
||||||
@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str:
|
|||||||
return "created_at"
|
return "created_at"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
|
||||||
|
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
|
||||||
|
if not user_metadata:
|
||||||
|
return None
|
||||||
|
filename = user_metadata.get("filename")
|
||||||
|
if not filename:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "input" in tags:
|
||||||
|
view_type = "input"
|
||||||
|
elif "output" in tags:
|
||||||
|
view_type = "output"
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
subfolder = ""
|
||||||
|
if "/" in filename:
|
||||||
|
subfolder, filename = filename.rsplit("/", 1)
|
||||||
|
|
||||||
|
encoded_filename = urllib.parse.quote(filename, safe="")
|
||||||
|
url = f"/api/view?type={view_type}&filename={encoded_filename}"
|
||||||
|
if subfolder:
|
||||||
|
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
|
||||||
|
"""Build an Asset response from a service result."""
|
||||||
|
if result.ref.preview_id:
|
||||||
|
preview_detail = get_asset_detail(result.ref.preview_id)
|
||||||
|
if preview_detail:
|
||||||
|
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
|
||||||
|
else:
|
||||||
|
preview_url = None
|
||||||
|
else:
|
||||||
|
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||||
|
return schemas_out.Asset(
|
||||||
|
id=result.ref.id,
|
||||||
|
name=result.ref.name,
|
||||||
|
asset_hash=result.asset.hash if result.asset else None,
|
||||||
|
size=int(result.asset.size_bytes) if result.asset else None,
|
||||||
|
mime_type=result.asset.mime_type if result.asset else None,
|
||||||
|
tags=result.tags,
|
||||||
|
preview_url=preview_url,
|
||||||
|
preview_id=result.ref.preview_id,
|
||||||
|
user_metadata=result.ref.user_metadata or {},
|
||||||
|
metadata=result.ref.system_metadata,
|
||||||
|
job_id=result.ref.job_id,
|
||||||
|
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||||
|
created_at=result.ref.created_at,
|
||||||
|
updated_at=result.ref.updated_at,
|
||||||
|
last_access_time=result.ref.last_access_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.head("/api/assets/hash/{hash}")
|
@ROUTES.head("/api/assets/hash/{hash}")
|
||||||
@_require_assets_feature_enabled
|
@_require_assets_feature_enabled
|
||||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||||
@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
|||||||
order=order,
|
order=order,
|
||||||
)
|
)
|
||||||
|
|
||||||
summaries = [
|
summaries = [_build_asset_response(item) for item in result.items]
|
||||||
schemas_out.AssetSummary(
|
|
||||||
id=item.ref.id,
|
|
||||||
name=item.ref.name,
|
|
||||||
asset_hash=item.asset.hash if item.asset else None,
|
|
||||||
size=int(item.asset.size_bytes) if item.asset else None,
|
|
||||||
mime_type=item.asset.mime_type if item.asset else None,
|
|
||||||
tags=item.tags,
|
|
||||||
created_at=item.ref.created_at,
|
|
||||||
updated_at=item.ref.updated_at,
|
|
||||||
last_access_time=item.ref.last_access_time,
|
|
||||||
)
|
|
||||||
for item in result.items
|
|
||||||
]
|
|
||||||
|
|
||||||
payload = schemas_out.AssetsList(
|
payload = schemas_out.AssetsList(
|
||||||
assets=summaries,
|
assets=summaries,
|
||||||
@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
|||||||
{"id": reference_id},
|
{"id": reference_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = schemas_out.AssetDetail(
|
payload = _build_asset_response(result)
|
||||||
id=result.ref.id,
|
|
||||||
name=result.ref.name,
|
|
||||||
asset_hash=result.asset.hash if result.asset else None,
|
|
||||||
size=int(result.asset.size_bytes) if result.asset else None,
|
|
||||||
mime_type=result.asset.mime_type if result.asset else None,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
preview_id=result.ref.preview_id,
|
|
||||||
created_at=result.ref.created_at,
|
|
||||||
last_access_time=result.ref.last_access_time,
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return _build_error_response(
|
return _build_error_response(
|
||||||
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
||||||
@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
|||||||
USER_MANAGER.get_request_user_id(request),
|
USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||||
@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
|||||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Derive name from hash if not provided
|
||||||
|
name = body.name
|
||||||
|
if name is None:
|
||||||
|
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
|
||||||
|
|
||||||
result = create_from_hash(
|
result = create_from_hash(
|
||||||
hash_str=body.hash,
|
hash_str=body.hash,
|
||||||
name=body.name,
|
name=name,
|
||||||
tags=body.tags,
|
tags=body.tags,
|
||||||
user_metadata=body.user_metadata,
|
user_metadata=body.user_metadata,
|
||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
mime_type=body.mime_type,
|
||||||
|
preview_id=body.preview_id,
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
return _build_error_response(
|
return _build_error_response(
|
||||||
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
asset = _build_asset_response(result)
|
||||||
payload_out = schemas_out.AssetCreated(
|
payload_out = schemas_out.AssetCreated(
|
||||||
id=result.ref.id,
|
**asset.model_dump(),
|
||||||
name=result.ref.name,
|
|
||||||
asset_hash=result.asset.hash,
|
|
||||||
size=int(result.asset.size_bytes),
|
|
||||||
mime_type=result.asset.mime_type,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
preview_id=result.ref.preview_id,
|
|
||||||
created_at=result.ref.created_at,
|
|
||||||
last_access_time=result.ref.last_access_time,
|
|
||||||
created_new=result.created_new,
|
created_new=result.created_new,
|
||||||
)
|
)
|
||||||
return web.json_response(payload_out.model_dump(mode="json"), status=201)
|
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post("/api/assets")
|
@ROUTES.post("/api/assets")
|
||||||
@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
"name": parsed.provided_name,
|
"name": parsed.provided_name,
|
||||||
"user_metadata": parsed.user_metadata_raw,
|
"user_metadata": parsed.user_metadata_raw,
|
||||||
"hash": parsed.provided_hash,
|
"hash": parsed.provided_hash,
|
||||||
|
"mime_type": parsed.provided_mime_type,
|
||||||
|
"preview_id": parsed.provided_preview_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except ValidationError as ve:
|
except ValidationError as ve:
|
||||||
@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
tags=spec.tags,
|
tags=spec.tags,
|
||||||
user_metadata=spec.user_metadata or {},
|
user_metadata=spec.user_metadata or {},
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
|
mime_type=spec.mime_type,
|
||||||
|
preview_id=spec.preview_id,
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
delete_temp_file_if_exists(parsed.tmp_path)
|
delete_temp_file_if_exists(parsed.tmp_path)
|
||||||
@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
client_filename=parsed.file_client_name,
|
client_filename=parsed.file_client_name,
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
expected_hash=spec.hash,
|
expected_hash=spec.hash,
|
||||||
|
mime_type=spec.mime_type,
|
||||||
|
preview_id=spec.preview_id,
|
||||||
)
|
)
|
||||||
except AssetValidationError as e:
|
except AssetValidationError as e:
|
||||||
delete_temp_file_if_exists(parsed.tmp_path)
|
delete_temp_file_if_exists(parsed.tmp_path)
|
||||||
@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
payload = schemas_out.AssetCreated(
|
asset = _build_asset_response(result)
|
||||||
id=result.ref.id,
|
payload_out = schemas_out.AssetCreated(
|
||||||
name=result.ref.name,
|
**asset.model_dump(),
|
||||||
asset_hash=result.asset.hash,
|
|
||||||
size=int(result.asset.size_bytes),
|
|
||||||
mime_type=result.asset.mime_type,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
preview_id=result.ref.preview_id,
|
|
||||||
created_at=result.ref.created_at,
|
|
||||||
last_access_time=result.ref.last_access_time,
|
|
||||||
created_new=result.created_new,
|
created_new=result.created_new,
|
||||||
)
|
)
|
||||||
status = 201 if result.created_new else 200
|
status = 201 if result.created_new else 200
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=status)
|
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
|||||||
name=body.name,
|
name=body.name,
|
||||||
user_metadata=body.user_metadata,
|
user_metadata=body.user_metadata,
|
||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
preview_id=body.preview_id,
|
||||||
)
|
)
|
||||||
payload = schemas_out.AssetUpdated(
|
payload = _build_asset_response(result)
|
||||||
id=result.ref.id,
|
|
||||||
name=result.ref.name,
|
|
||||||
asset_hash=result.asset.hash if result.asset else None,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
updated_at=result.ref.updated_at,
|
|
||||||
)
|
|
||||||
except PermissionError as pe:
|
except PermissionError as pe:
|
||||||
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
|||||||
USER_MANAGER.get_request_user_id(request),
|
USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
|||||||
payload = schemas_out.TagsList(
|
payload = schemas_out.TagsList(
|
||||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||||
)
|
)
|
||||||
return web.json_response(payload.model_dump(mode="json"))
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
|||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
|||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.get("/api/assets/tags/refine")
|
||||||
|
@_require_assets_feature_enabled
|
||||||
|
async def get_tags_refine(request: web.Request) -> web.Response:
|
||||||
|
"""GET request to get tag histogram for filtered assets."""
|
||||||
|
query_dict = get_query_dict(request)
|
||||||
|
try:
|
||||||
|
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _build_validation_error_response("INVALID_QUERY", ve)
|
||||||
|
|
||||||
|
tag_counts = list_tag_histogram(
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
include_tags=q.include_tags,
|
||||||
|
exclude_tags=q.exclude_tags,
|
||||||
|
name_contains=q.name_contains,
|
||||||
|
metadata_filter=q.metadata_filter,
|
||||||
|
limit=q.limit,
|
||||||
|
)
|
||||||
|
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
|
||||||
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post("/api/assets/seed")
|
@ROUTES.post("/api/assets/seed")
|
||||||
|
|||||||
@ -45,6 +45,8 @@ class ParsedUpload:
|
|||||||
user_metadata_raw: str | None
|
user_metadata_raw: str | None
|
||||||
provided_hash: str | None
|
provided_hash: str | None
|
||||||
provided_hash_exists: bool | None
|
provided_hash_exists: bool | None
|
||||||
|
provided_mime_type: str | None = None
|
||||||
|
provided_preview_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListAssetsQuery(BaseModel):
|
class ListAssetsQuery(BaseModel):
|
||||||
@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel):
|
|||||||
class UpdateAssetBody(BaseModel):
|
class UpdateAssetBody(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
user_metadata: dict[str, Any] | None = None
|
user_metadata: dict[str, Any] | None = None
|
||||||
|
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_at_least_one_field(self):
|
def _validate_at_least_one_field(self):
|
||||||
if self.name is None and self.user_metadata is None:
|
if all(
|
||||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
v is None
|
||||||
|
for v in (self.name, self.user_metadata, self.preview_id)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Provide at least one of: name, user_metadata, preview_id."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel):
|
|||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
hash: str
|
hash: str
|
||||||
name: str
|
name: str | None = None
|
||||||
tags: list[str] = Field(default_factory=list)
|
tags: list[str] = Field(default_factory=list)
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
mime_type: str | None = None
|
||||||
|
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||||
|
|
||||||
@field_validator("hash")
|
@field_validator("hash")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRefineQuery(BaseModel):
|
||||||
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
|
name_contains: str | None = None
|
||||||
|
metadata_filter: dict[str, Any] | None = None
|
||||||
|
limit: conint(ge=1, le=1000) = 100
|
||||||
|
|
||||||
|
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _split_csv_tags(cls, v):
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
return [t.strip() for t in v.split(",") if t.strip()]
|
||||||
|
if isinstance(v, list):
|
||||||
|
out: list[str] = []
|
||||||
|
for item in v:
|
||||||
|
if isinstance(item, str):
|
||||||
|
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||||
|
return out
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("metadata_filter", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_metadata_json(cls, v):
|
||||||
|
if v is None or isinstance(v, dict):
|
||||||
|
return v
|
||||||
|
if isinstance(v, str) and v.strip():
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError("metadata_filter must be a JSON object")
|
||||||
|
return parsed
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class TagsListQuery(BaseModel):
|
class TagsListQuery(BaseModel):
|
||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
@ -186,21 +234,25 @@ class TagsRemove(TagsAdd):
|
|||||||
class UploadAssetSpec(BaseModel):
|
class UploadAssetSpec(BaseModel):
|
||||||
"""Upload Asset operation.
|
"""Upload Asset operation.
|
||||||
|
|
||||||
- tags: ordered; first is root ('models'|'input'|'output');
|
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||||
if root == 'models', second must be a valid category
|
if root == 'models', second must be a valid category
|
||||||
- name: display name
|
- name: display name
|
||||||
- user_metadata: arbitrary JSON object (optional)
|
- user_metadata: arbitrary JSON object (optional)
|
||||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||||
|
- mime_type: optional MIME type override
|
||||||
|
- preview_id: optional asset_reference ID for preview
|
||||||
|
|
||||||
Files are stored using the content hash as filename stem.
|
Files are stored using the content hash as filename stem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
tags: list[str] = Field(..., min_length=1)
|
tags: list[str] = Field(default_factory=list)
|
||||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
hash: str | None = Field(default=None)
|
hash: str | None = Field(default=None)
|
||||||
|
mime_type: str | None = Field(default=None)
|
||||||
|
preview_id: str | None = Field(default=None) # references an asset_reference id
|
||||||
|
|
||||||
@field_validator("hash", mode="before")
|
@field_validator("hash", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_order(self):
|
def _validate_order(self):
|
||||||
if not self.tags:
|
if not self.tags:
|
||||||
raise ValueError("tags must be provided and non-empty")
|
raise ValueError("at least one tag is required for uploads")
|
||||||
root = self.tags[0]
|
root = self.tags[0]
|
||||||
if root not in {"models", "input", "output"}:
|
if root not in {"models", "input", "output"}:
|
||||||
raise ValueError("first tag must be one of: models, input, output")
|
raise ValueError("first tag must be one of: models, input, output")
|
||||||
|
|||||||
@ -4,7 +4,10 @@ from typing import Any
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||||
|
|
||||||
|
|
||||||
class AssetSummary(BaseModel):
|
class Asset(BaseModel):
|
||||||
|
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
|
||||||
|
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
asset_hash: str | None = None
|
asset_hash: str | None = None
|
||||||
@ -12,8 +15,14 @@ class AssetSummary(BaseModel):
|
|||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
tags: list[str] = Field(default_factory=list)
|
tags: list[str] = Field(default_factory=list)
|
||||||
preview_url: str | None = None
|
preview_url: str | None = None
|
||||||
created_at: datetime | None = None
|
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||||
updated_at: datetime | None = None
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
is_immutable: bool = False
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
job_id: str | None = None
|
||||||
|
prompt_id: str | None = None # deprecated: use job_id
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
last_access_time: datetime | None = None
|
last_access_time: datetime | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
@ -23,50 +32,16 @@ class AssetSummary(BaseModel):
|
|||||||
return v.isoformat() if v else None
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
class AssetCreated(Asset):
|
||||||
|
created_new: bool
|
||||||
|
|
||||||
|
|
||||||
class AssetsList(BaseModel):
|
class AssetsList(BaseModel):
|
||||||
assets: list[AssetSummary]
|
assets: list[Asset]
|
||||||
total: int
|
total: int
|
||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
class AssetUpdated(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
asset_hash: str | None = None
|
|
||||||
tags: list[str] = Field(default_factory=list)
|
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
updated_at: datetime | None = None
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
@field_serializer("updated_at")
|
|
||||||
def _serialize_updated_at(self, v: datetime | None, _info):
|
|
||||||
return v.isoformat() if v else None
|
|
||||||
|
|
||||||
|
|
||||||
class AssetDetail(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
asset_hash: str | None = None
|
|
||||||
size: int | None = None
|
|
||||||
mime_type: str | None = None
|
|
||||||
tags: list[str] = Field(default_factory=list)
|
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
preview_id: str | None = None
|
|
||||||
created_at: datetime | None = None
|
|
||||||
last_access_time: datetime | None = None
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
@field_serializer("created_at", "last_access_time")
|
|
||||||
def _serialize_datetime(self, v: datetime | None, _info):
|
|
||||||
return v.isoformat() if v else None
|
|
||||||
|
|
||||||
|
|
||||||
class AssetCreated(AssetDetail):
|
|
||||||
created_new: bool
|
|
||||||
|
|
||||||
|
|
||||||
class TagUsage(BaseModel):
|
class TagUsage(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
|
|||||||
removed: list[str] = Field(default_factory=list)
|
removed: list[str] = Field(default_factory=list)
|
||||||
not_present: list[str] = Field(default_factory=list)
|
not_present: list[str] = Field(default_factory=list)
|
||||||
total_tags: list[str] = Field(default_factory=list)
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TagHistogram(BaseModel):
|
||||||
|
tag_counts: dict[str, int]
|
||||||
|
|||||||
@ -52,6 +52,8 @@ async def parse_multipart_upload(
|
|||||||
user_metadata_raw: str | None = None
|
user_metadata_raw: str | None = None
|
||||||
provided_hash: str | None = None
|
provided_hash: str | None = None
|
||||||
provided_hash_exists: bool | None = None
|
provided_hash_exists: bool | None = None
|
||||||
|
provided_mime_type: str | None = None
|
||||||
|
provided_preview_id: str | None = None
|
||||||
|
|
||||||
file_written = 0
|
file_written = 0
|
||||||
tmp_path: str | None = None
|
tmp_path: str | None = None
|
||||||
@ -128,6 +130,16 @@ async def parse_multipart_upload(
|
|||||||
provided_name = (await field.text()) or None
|
provided_name = (await field.text()) or None
|
||||||
elif fname == "user_metadata":
|
elif fname == "user_metadata":
|
||||||
user_metadata_raw = (await field.text()) or None
|
user_metadata_raw = (await field.text()) or None
|
||||||
|
elif fname == "id":
|
||||||
|
raise UploadError(
|
||||||
|
400,
|
||||||
|
"UNSUPPORTED_FIELD",
|
||||||
|
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
|
||||||
|
)
|
||||||
|
elif fname == "mime_type":
|
||||||
|
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||||
|
elif fname == "preview_id":
|
||||||
|
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||||
|
|
||||||
if not file_present and not (provided_hash and provided_hash_exists):
|
if not file_present and not (provided_hash and provided_hash_exists):
|
||||||
raise UploadError(
|
raise UploadError(
|
||||||
@ -152,6 +164,8 @@ async def parse_multipart_upload(
|
|||||||
user_metadata_raw=user_metadata_raw,
|
user_metadata_raw=user_metadata_raw,
|
||||||
provided_hash=provided_hash,
|
provided_hash=provided_hash,
|
||||||
provided_hash_exists=provided_hash_exists,
|
provided_hash_exists=provided_hash_exists,
|
||||||
|
provided_mime_type=provided_mime_type,
|
||||||
|
provided_preview_id=provided_preview_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -45,13 +45,7 @@ class Asset(Base):
|
|||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
preview_of: Mapped[list[AssetReference]] = relationship(
|
# preview_id on AssetReference is a self-referential FK to asset_references.id
|
||||||
"AssetReference",
|
|
||||||
back_populates="preview_asset",
|
|
||||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
|
|
||||||
foreign_keys=lambda: [AssetReference.preview_id],
|
|
||||||
viewonly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("uq_assets_hash", "hash", unique=True),
|
Index("uq_assets_hash", "hash", unique=True),
|
||||||
@ -91,11 +85,15 @@ class AssetReference(Base):
|
|||||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
preview_id: Mapped[str | None] = mapped_column(
|
preview_id: Mapped[str | None] = mapped_column(
|
||||||
String(36), ForeignKey("assets.id", ondelete="SET NULL")
|
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
|
||||||
)
|
)
|
||||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||||
JSON(none_as_null=True)
|
JSON(none_as_null=True)
|
||||||
)
|
)
|
||||||
|
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||||
|
JSON(none_as_null=True), nullable=True, default=None
|
||||||
|
)
|
||||||
|
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||||
)
|
)
|
||||||
@ -115,10 +113,10 @@ class AssetReference(Base):
|
|||||||
foreign_keys=[asset_id],
|
foreign_keys=[asset_id],
|
||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
)
|
)
|
||||||
preview_asset: Mapped[Asset | None] = relationship(
|
preview_ref: Mapped[AssetReference | None] = relationship(
|
||||||
"Asset",
|
"AssetReference",
|
||||||
back_populates="preview_of",
|
|
||||||
foreign_keys=[preview_id],
|
foreign_keys=[preview_id],
|
||||||
|
remote_side=lambda: [AssetReference.id],
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||||
@ -152,6 +150,7 @@ class AssetReference(Base):
|
|||||||
Index("ix_asset_references_created_at", "created_at"),
|
Index("ix_asset_references_created_at", "created_at"),
|
||||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||||
|
Index("ix_asset_references_preview_id", "preview_id"),
|
||||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||||
CheckConstraint(
|
CheckConstraint(
|
||||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||||
@ -192,6 +191,10 @@ class AssetReferenceMeta(Base):
|
|||||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||||
|
CheckConstraint(
|
||||||
|
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||||
|
name="has_value",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
get_unreferenced_unhashed_asset_ids,
|
get_unreferenced_unhashed_asset_ids,
|
||||||
insert_reference,
|
insert_reference,
|
||||||
|
list_all_file_paths_by_asset_id,
|
||||||
list_references_by_asset_id,
|
list_references_by_asset_id,
|
||||||
list_references_page,
|
list_references_page,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
|
rebuild_metadata_projection,
|
||||||
|
reference_exists,
|
||||||
reference_exists_for_asset_id,
|
reference_exists_for_asset_id,
|
||||||
restore_references_by_paths,
|
restore_references_by_paths,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
set_reference_preview,
|
set_reference_preview,
|
||||||
|
set_reference_system_metadata,
|
||||||
soft_delete_reference_by_id,
|
soft_delete_reference_by_id,
|
||||||
update_reference_access_time,
|
update_reference_access_time,
|
||||||
update_reference_name,
|
update_reference_name,
|
||||||
|
update_is_missing_by_asset_id,
|
||||||
update_reference_timestamps,
|
update_reference_timestamps,
|
||||||
update_reference_updated_at,
|
update_reference_updated_at,
|
||||||
upsert_reference,
|
upsert_reference,
|
||||||
@ -54,6 +59,7 @@ from app.assets.database.queries.tags import (
|
|||||||
bulk_insert_tags_and_meta,
|
bulk_insert_tags_and_meta,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_reference_tags,
|
get_reference_tags,
|
||||||
|
list_tag_counts_for_filtered_assets,
|
||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
remove_tags_from_reference,
|
remove_tags_from_reference,
|
||||||
@ -97,20 +103,26 @@ __all__ = [
|
|||||||
"get_unenriched_references",
|
"get_unenriched_references",
|
||||||
"get_unreferenced_unhashed_asset_ids",
|
"get_unreferenced_unhashed_asset_ids",
|
||||||
"insert_reference",
|
"insert_reference",
|
||||||
|
"list_all_file_paths_by_asset_id",
|
||||||
"list_references_by_asset_id",
|
"list_references_by_asset_id",
|
||||||
"list_references_page",
|
"list_references_page",
|
||||||
|
"list_tag_counts_for_filtered_assets",
|
||||||
"list_tags_with_usage",
|
"list_tags_with_usage",
|
||||||
"mark_references_missing_outside_prefixes",
|
"mark_references_missing_outside_prefixes",
|
||||||
"reassign_asset_references",
|
"reassign_asset_references",
|
||||||
|
"rebuild_metadata_projection",
|
||||||
|
"reference_exists",
|
||||||
"reference_exists_for_asset_id",
|
"reference_exists_for_asset_id",
|
||||||
"remove_missing_tag_for_asset_id",
|
"remove_missing_tag_for_asset_id",
|
||||||
"remove_tags_from_reference",
|
"remove_tags_from_reference",
|
||||||
"restore_references_by_paths",
|
"restore_references_by_paths",
|
||||||
"set_reference_metadata",
|
"set_reference_metadata",
|
||||||
"set_reference_preview",
|
"set_reference_preview",
|
||||||
|
"set_reference_system_metadata",
|
||||||
"soft_delete_reference_by_id",
|
"soft_delete_reference_by_id",
|
||||||
"set_reference_tags",
|
"set_reference_tags",
|
||||||
"update_asset_hash_and_mime",
|
"update_asset_hash_and_mime",
|
||||||
|
"update_is_missing_by_asset_id",
|
||||||
"update_reference_access_time",
|
"update_reference_access_time",
|
||||||
"update_reference_name",
|
"update_reference_name",
|
||||||
"update_reference_timestamps",
|
"update_reference_timestamps",
|
||||||
|
|||||||
@ -69,7 +69,7 @@ def upsert_asset(
|
|||||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||||
asset.size_bytes = int(size_bytes)
|
asset.size_bytes = int(size_bytes)
|
||||||
changed = True
|
changed = True
|
||||||
if mime_type and asset.mime_type != mime_type:
|
if mime_type and not asset.mime_type:
|
||||||
asset.mime_type = mime_type
|
asset.mime_type = mime_type
|
||||||
changed = True
|
changed = True
|
||||||
if changed:
|
if changed:
|
||||||
@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
|
|||||||
return False
|
return False
|
||||||
if asset_hash is not None:
|
if asset_hash is not None:
|
||||||
asset.hash = asset_hash
|
asset.hash = asset_hash
|
||||||
if mime_type is not None:
|
if mime_type is not None and not asset.mime_type:
|
||||||
asset.mime_type = mime_type
|
asset.mime_type = mime_type
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from decimal import Decimal
|
|||||||
from typing import NamedTuple, Sequence
|
from typing import NamedTuple, Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import delete, exists, select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.dialects import sqlite
|
from sqlalchemy.dialects import sqlite
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session, noload
|
from sqlalchemy.orm import Session, noload
|
||||||
@ -24,12 +24,14 @@ from app.assets.database.models import (
|
|||||||
)
|
)
|
||||||
from app.assets.database.queries.common import (
|
from app.assets.database.queries.common import (
|
||||||
MAX_BIND_PARAMS,
|
MAX_BIND_PARAMS,
|
||||||
|
apply_metadata_filter,
|
||||||
|
apply_tag_filters,
|
||||||
build_prefix_like_conditions,
|
build_prefix_like_conditions,
|
||||||
build_visible_owner_clause,
|
build_visible_owner_clause,
|
||||||
calculate_rows_per_statement,
|
calculate_rows_per_statement,
|
||||||
iter_chunks,
|
iter_chunks,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
from app.assets.helpers import escape_sql_like_string, get_utc_now
|
||||||
|
|
||||||
|
|
||||||
def _check_is_scalar(v):
|
def _check_is_scalar(v):
|
||||||
@ -44,15 +46,6 @@ def _check_is_scalar(v):
|
|||||||
|
|
||||||
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||||
"""Convert a scalar value to a typed projection row."""
|
"""Convert a scalar value to a typed projection row."""
|
||||||
if value is None:
|
|
||||||
return {
|
|
||||||
"key": key,
|
|
||||||
"ordinal": ordinal,
|
|
||||||
"val_str": None,
|
|
||||||
"val_num": None,
|
|
||||||
"val_bool": None,
|
|
||||||
"val_json": None,
|
|
||||||
}
|
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
||||||
if isinstance(value, (int, float, Decimal)):
|
if isinstance(value, (int, float, Decimal)):
|
||||||
@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
|||||||
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
||||||
"""Turn a metadata key/value into typed projection rows."""
|
"""Turn a metadata key/value into typed projection rows."""
|
||||||
if value is None:
|
if value is None:
|
||||||
return [_scalar_to_row(key, 0, None)]
|
return []
|
||||||
|
|
||||||
if _check_is_scalar(value):
|
if _check_is_scalar(value):
|
||||||
return [_scalar_to_row(key, 0, value)]
|
return [_scalar_to_row(key, 0, value)]
|
||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
if all(_check_is_scalar(x) for x in value):
|
if all(_check_is_scalar(x) for x in value):
|
||||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
|
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
|
||||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
|
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
|
||||||
|
|
||||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||||
|
|
||||||
|
|
||||||
def _apply_tag_filters(
|
|
||||||
stmt: sa.sql.Select,
|
|
||||||
include_tags: Sequence[str] | None = None,
|
|
||||||
exclude_tags: Sequence[str] | None = None,
|
|
||||||
) -> sa.sql.Select:
|
|
||||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
|
||||||
include_tags = normalize_tags(include_tags)
|
|
||||||
exclude_tags = normalize_tags(exclude_tags)
|
|
||||||
|
|
||||||
if include_tags:
|
|
||||||
for tag_name in include_tags:
|
|
||||||
stmt = stmt.where(
|
|
||||||
exists().where(
|
|
||||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
|
||||||
& (AssetReferenceTag.tag_name == tag_name)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if exclude_tags:
|
|
||||||
stmt = stmt.where(
|
|
||||||
~exists().where(
|
|
||||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
|
||||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_metadata_filter(
|
|
||||||
stmt: sa.sql.Select,
|
|
||||||
metadata_filter: dict | None = None,
|
|
||||||
) -> sa.sql.Select:
|
|
||||||
"""Apply filters using asset_reference_meta projection table."""
|
|
||||||
if not metadata_filter:
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
|
||||||
return sa.exists().where(
|
|
||||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
|
||||||
AssetReferenceMeta.key == key,
|
|
||||||
*preds,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
|
||||||
if value is None:
|
|
||||||
no_row_for_key = sa.not_(
|
|
||||||
sa.exists().where(
|
|
||||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
|
||||||
AssetReferenceMeta.key == key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
null_row = _exists_for_pred(
|
|
||||||
key,
|
|
||||||
AssetReferenceMeta.val_json.is_(None),
|
|
||||||
AssetReferenceMeta.val_str.is_(None),
|
|
||||||
AssetReferenceMeta.val_num.is_(None),
|
|
||||||
AssetReferenceMeta.val_bool.is_(None),
|
|
||||||
)
|
|
||||||
return sa.or_(no_row_for_key, null_row)
|
|
||||||
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
|
||||||
if isinstance(value, (int, float, Decimal)):
|
|
||||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
|
||||||
if isinstance(value, str):
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
|
||||||
|
|
||||||
for k, v in metadata_filter.items():
|
|
||||||
if isinstance(v, list):
|
|
||||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
|
||||||
if ors:
|
|
||||||
stmt = stmt.where(sa.or_(*ors))
|
|
||||||
else:
|
|
||||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def get_reference_by_id(
|
def get_reference_by_id(
|
||||||
@ -212,6 +128,21 @@ def reference_exists_for_asset_id(
|
|||||||
return session.execute(q).first() is not None
|
return session.execute(q).first() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def reference_exists(
|
||||||
|
session: Session,
|
||||||
|
reference_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if a reference with the given ID exists (not soft-deleted)."""
|
||||||
|
q = (
|
||||||
|
select(sa.literal(True))
|
||||||
|
.select_from(AssetReference)
|
||||||
|
.where(AssetReference.id == reference_id)
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
return session.execute(q).first() is not None
|
||||||
|
|
||||||
|
|
||||||
def insert_reference(
|
def insert_reference(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
@ -336,8 +267,8 @@ def list_references_page(
|
|||||||
escaped, esc = escape_sql_like_string(name_contains)
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
base = _apply_metadata_filter(base, metadata_filter)
|
base = apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
sort = (sort or "created_at").lower()
|
sort = (sort or "created_at").lower()
|
||||||
order = (order or "desc").lower()
|
order = (order or "desc").lower()
|
||||||
@ -366,8 +297,8 @@ def list_references_page(
|
|||||||
count_stmt = count_stmt.where(
|
count_stmt = count_stmt.where(
|
||||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||||
)
|
)
|
||||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
total = int(session.execute(count_stmt).scalar_one() or 0)
|
total = int(session.execute(count_stmt).scalar_one() or 0)
|
||||||
refs = session.execute(base).unique().scalars().all()
|
refs = session.execute(base).unique().scalars().all()
|
||||||
@ -379,7 +310,7 @@ def list_references_page(
|
|||||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||||
.order_by(AssetReferenceTag.added_at)
|
.order_by(AssetReferenceTag.tag_name.asc())
|
||||||
)
|
)
|
||||||
for ref_id, tag_name in rows.all():
|
for ref_id, tag_name in rows.all():
|
||||||
tag_map[ref_id].append(tag_name)
|
tag_map[ref_id].append(tag_name)
|
||||||
@ -492,6 +423,42 @@ def update_reference_updated_at(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
|
||||||
|
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
|
||||||
|
|
||||||
|
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
|
||||||
|
override system keys of the same name.
|
||||||
|
"""
|
||||||
|
session.execute(
|
||||||
|
delete(AssetReferenceMeta).where(
|
||||||
|
AssetReferenceMeta.asset_reference_id == ref.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
|
||||||
|
if not merged:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows: list[AssetReferenceMeta] = []
|
||||||
|
for k, v in merged.items():
|
||||||
|
for r in convert_metadata_to_rows(k, v):
|
||||||
|
rows.append(
|
||||||
|
AssetReferenceMeta(
|
||||||
|
asset_reference_id=ref.id,
|
||||||
|
key=r["key"],
|
||||||
|
ordinal=int(r["ordinal"]),
|
||||||
|
val_str=r.get("val_str"),
|
||||||
|
val_num=r.get("val_num"),
|
||||||
|
val_bool=r.get("val_bool"),
|
||||||
|
val_json=r.get("val_json"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rows:
|
||||||
|
session.add_all(rows)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
|
||||||
def set_reference_metadata(
|
def set_reference_metadata(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
@ -505,33 +472,24 @@ def set_reference_metadata(
|
|||||||
ref.updated_at = get_utc_now()
|
ref.updated_at = get_utc_now()
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
session.execute(
|
rebuild_metadata_projection(session, ref)
|
||||||
delete(AssetReferenceMeta).where(
|
|
||||||
AssetReferenceMeta.asset_reference_id == reference_id
|
|
||||||
)
|
def set_reference_system_metadata(
|
||||||
)
|
session: Session,
|
||||||
|
reference_id: str,
|
||||||
|
system_metadata: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set system_metadata on a reference and rebuild the merged projection."""
|
||||||
|
ref = session.get(AssetReference, reference_id)
|
||||||
|
if not ref:
|
||||||
|
raise ValueError(f"AssetReference {reference_id} not found")
|
||||||
|
|
||||||
|
ref.system_metadata = system_metadata or {}
|
||||||
|
ref.updated_at = get_utc_now()
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
if not user_metadata:
|
rebuild_metadata_projection(session, ref)
|
||||||
return
|
|
||||||
|
|
||||||
rows: list[AssetReferenceMeta] = []
|
|
||||||
for k, v in user_metadata.items():
|
|
||||||
for r in convert_metadata_to_rows(k, v):
|
|
||||||
rows.append(
|
|
||||||
AssetReferenceMeta(
|
|
||||||
asset_reference_id=reference_id,
|
|
||||||
key=r["key"],
|
|
||||||
ordinal=int(r["ordinal"]),
|
|
||||||
val_str=r.get("val_str"),
|
|
||||||
val_num=r.get("val_num"),
|
|
||||||
val_bool=r.get("val_bool"),
|
|
||||||
val_json=r.get("val_json"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if rows:
|
|
||||||
session.add_all(rows)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def delete_reference_by_id(
|
def delete_reference_by_id(
|
||||||
@ -571,19 +529,19 @@ def soft_delete_reference_by_id(
|
|||||||
def set_reference_preview(
|
def set_reference_preview(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
preview_asset_id: str | None = None,
|
preview_reference_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||||
ref = session.get(AssetReference, reference_id)
|
ref = session.get(AssetReference, reference_id)
|
||||||
if not ref:
|
if not ref:
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
raise ValueError(f"AssetReference {reference_id} not found")
|
||||||
|
|
||||||
if preview_asset_id is None:
|
if preview_reference_id is None:
|
||||||
ref.preview_id = None
|
ref.preview_id = None
|
||||||
else:
|
else:
|
||||||
if not session.get(Asset, preview_asset_id):
|
if not session.get(AssetReference, preview_reference_id):
|
||||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
|
||||||
ref.preview_id = preview_asset_id
|
ref.preview_id = preview_reference_id
|
||||||
|
|
||||||
ref.updated_at = get_utc_now()
|
ref.updated_at = get_utc_now()
|
||||||
session.flush()
|
session.flush()
|
||||||
@ -609,6 +567,8 @@ def list_references_by_asset_id(
|
|||||||
session.execute(
|
session.execute(
|
||||||
select(AssetReference)
|
select(AssetReference)
|
||||||
.where(AssetReference.asset_id == asset_id)
|
.where(AssetReference.asset_id == asset_id)
|
||||||
|
.where(AssetReference.is_missing == False) # noqa: E712
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
.order_by(AssetReference.id.asc())
|
.order_by(AssetReference.id.asc())
|
||||||
)
|
)
|
||||||
.scalars()
|
.scalars()
|
||||||
@ -616,6 +576,25 @@ def list_references_by_asset_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_all_file_paths_by_asset_id(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return every file_path for an asset, including soft-deleted/missing refs.
|
||||||
|
|
||||||
|
Used for orphan cleanup where all on-disk files must be removed.
|
||||||
|
"""
|
||||||
|
return list(
|
||||||
|
session.execute(
|
||||||
|
select(AssetReference.file_path)
|
||||||
|
.where(AssetReference.asset_id == asset_id)
|
||||||
|
.where(AssetReference.file_path.isnot(None))
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def upsert_reference(
|
def upsert_reference(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
@ -855,6 +834,22 @@ def bulk_update_is_missing(
|
|||||||
return total
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def update_is_missing_by_asset_id(
|
||||||
|
session: Session, asset_id: str, value: bool
|
||||||
|
) -> int:
|
||||||
|
"""Set is_missing flag for ALL references belonging to an asset.
|
||||||
|
|
||||||
|
Returns: Number of rows updated
|
||||||
|
"""
|
||||||
|
result = session.execute(
|
||||||
|
sa.update(AssetReference)
|
||||||
|
.where(AssetReference.asset_id == asset_id)
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
|
.values(is_missing=value)
|
||||||
|
)
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
|
||||||
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
||||||
"""Delete references by their IDs.
|
"""Delete references by their IDs.
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
"""Shared utilities for database query modules."""
|
"""Shared utilities for database query modules."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Iterable
|
from decimal import Decimal
|
||||||
|
from typing import Iterable, Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import exists
|
||||||
|
|
||||||
from app.assets.database.models import AssetReference
|
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
||||||
from app.assets.helpers import escape_sql_like_string
|
from app.assets.helpers import escape_sql_like_string, normalize_tags
|
||||||
|
|
||||||
MAX_BIND_PARAMS = 800
|
MAX_BIND_PARAMS = 800
|
||||||
|
|
||||||
@ -52,3 +54,74 @@ def build_prefix_like_conditions(
|
|||||||
escaped, esc = escape_sql_like_string(base)
|
escaped, esc = escape_sql_like_string(base)
|
||||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
|
||||||
|
def apply_tag_filters(
|
||||||
|
stmt: sa.sql.Select,
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
) -> sa.sql.Select:
|
||||||
|
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||||
|
include_tags = normalize_tags(include_tags)
|
||||||
|
exclude_tags = normalize_tags(exclude_tags)
|
||||||
|
|
||||||
|
if include_tags:
|
||||||
|
for tag_name in include_tags:
|
||||||
|
stmt = stmt.where(
|
||||||
|
exists().where(
|
||||||
|
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||||
|
& (AssetReferenceTag.tag_name == tag_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_tags:
|
||||||
|
stmt = stmt.where(
|
||||||
|
~exists().where(
|
||||||
|
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||||
|
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def apply_metadata_filter(
|
||||||
|
stmt: sa.sql.Select,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
) -> sa.sql.Select:
|
||||||
|
"""Apply filters using asset_reference_meta projection table."""
|
||||||
|
if not metadata_filter:
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||||
|
return sa.exists().where(
|
||||||
|
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||||
|
AssetReferenceMeta.key == key,
|
||||||
|
*preds,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||||
|
if value is None:
|
||||||
|
return sa.not_(
|
||||||
|
sa.exists().where(
|
||||||
|
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||||
|
AssetReferenceMeta.key == key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||||
|
if isinstance(value, (int, float, Decimal)):
|
||||||
|
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||||
|
|
||||||
|
for k, v in metadata_filter.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||||
|
if ors:
|
||||||
|
stmt = stmt.where(sa.or_(*ors))
|
||||||
|
else:
|
||||||
|
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||||
|
return stmt
|
||||||
|
|||||||
@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import (
|
from app.assets.database.models import (
|
||||||
|
Asset,
|
||||||
AssetReference,
|
AssetReference,
|
||||||
AssetReferenceMeta,
|
AssetReferenceMeta,
|
||||||
AssetReferenceTag,
|
AssetReferenceTag,
|
||||||
Tag,
|
Tag,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries.common import (
|
from app.assets.database.queries.common import (
|
||||||
|
apply_metadata_filter,
|
||||||
|
apply_tag_filters,
|
||||||
build_visible_owner_clause,
|
build_visible_owner_clause,
|
||||||
iter_row_chunks,
|
iter_row_chunks,
|
||||||
)
|
)
|
||||||
@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
|||||||
tag_name
|
tag_name
|
||||||
for (tag_name,) in (
|
for (tag_name,) in (
|
||||||
session.execute(
|
session.execute(
|
||||||
select(AssetReferenceTag.tag_name).where(
|
select(AssetReferenceTag.tag_name)
|
||||||
AssetReferenceTag.asset_reference_id == reference_id
|
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||||
)
|
.order_by(AssetReferenceTag.tag_name.asc())
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
]
|
]
|
||||||
@ -117,7 +120,7 @@ def set_reference_tags(
|
|||||||
)
|
)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
|
||||||
|
|
||||||
|
|
||||||
def add_tags_to_reference(
|
def add_tags_to_reference(
|
||||||
@ -272,6 +275,12 @@ def list_tags_with_usage(
|
|||||||
.select_from(AssetReferenceTag)
|
.select_from(AssetReferenceTag)
|
||||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||||
.where(build_visible_owner_clause(owner_id))
|
.where(build_visible_owner_clause(owner_id))
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetReference.is_missing == False, # noqa: E712
|
||||||
|
AssetReferenceTag.tag_name == "missing",
|
||||||
|
)
|
||||||
|
)
|
||||||
.where(AssetReference.deleted_at.is_(None))
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
.group_by(AssetReferenceTag.tag_name)
|
.group_by(AssetReferenceTag.tag_name)
|
||||||
.subquery()
|
.subquery()
|
||||||
@ -308,6 +317,12 @@ def list_tags_with_usage(
|
|||||||
select(AssetReferenceTag.tag_name)
|
select(AssetReferenceTag.tag_name)
|
||||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||||
.where(build_visible_owner_clause(owner_id))
|
.where(build_visible_owner_clause(owner_id))
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetReference.is_missing == False, # noqa: E712
|
||||||
|
AssetReferenceTag.tag_name == "missing",
|
||||||
|
)
|
||||||
|
)
|
||||||
.where(AssetReference.deleted_at.is_(None))
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
.group_by(AssetReferenceTag.tag_name)
|
.group_by(AssetReferenceTag.tag_name)
|
||||||
)
|
)
|
||||||
@ -320,6 +335,53 @@ def list_tags_with_usage(
|
|||||||
return rows_norm, int(total or 0)
|
return rows_norm, int(total or 0)
|
||||||
|
|
||||||
|
|
||||||
|
def list_tag_counts_for_filtered_assets(
|
||||||
|
session: Session,
|
||||||
|
owner_id: str = "",
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
name_contains: str | None = None,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Return tag counts for assets matching the given filters.
|
||||||
|
|
||||||
|
Uses the same filtering logic as list_references_page but returns
|
||||||
|
{tag_name: count} instead of paginated references.
|
||||||
|
"""
|
||||||
|
# Build a subquery of matching reference IDs
|
||||||
|
ref_sq = (
|
||||||
|
select(AssetReference.id)
|
||||||
|
.join(Asset, Asset.id == AssetReference.asset_id)
|
||||||
|
.where(build_visible_owner_clause(owner_id))
|
||||||
|
.where(AssetReference.is_missing == False) # noqa: E712
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
|
)
|
||||||
|
|
||||||
|
if name_contains:
|
||||||
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
|
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
|
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
|
||||||
|
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
|
||||||
|
ref_sq = ref_sq.subquery()
|
||||||
|
|
||||||
|
# Count tags across those references
|
||||||
|
q = (
|
||||||
|
select(
|
||||||
|
AssetReferenceTag.tag_name,
|
||||||
|
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||||
|
)
|
||||||
|
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
|
||||||
|
.group_by(AssetReferenceTag.tag_name)
|
||||||
|
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = session.execute(q).all()
|
||||||
|
return {tag_name: int(cnt) for tag_name, cnt in rows}
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_tags_and_meta(
|
def bulk_insert_tags_and_meta(
|
||||||
session: Session,
|
session: Session,
|
||||||
tag_rows: list[dict],
|
tag_rows: list[dict],
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from app.assets.database.queries import (
|
|||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_system_metadata,
|
||||||
update_asset_hash_and_mime,
|
update_asset_hash_and_mime,
|
||||||
)
|
)
|
||||||
from app.assets.services.bulk_ingest import (
|
from app.assets.services.bulk_ingest import (
|
||||||
@ -490,8 +490,8 @@ def enrich_asset(
|
|||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
user_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_metadata(session, reference_id, user_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|
||||||
if full_hash:
|
if full_hash:
|
||||||
existing = get_asset_by_hash(session, full_hash)
|
existing = get_asset_by_hash(session, full_hash)
|
||||||
|
|||||||
@ -16,10 +16,12 @@ from app.assets.database.queries import (
|
|||||||
get_reference_by_id,
|
get_reference_by_id,
|
||||||
get_reference_with_owner_check,
|
get_reference_with_owner_check,
|
||||||
list_references_page,
|
list_references_page,
|
||||||
|
list_all_file_paths_by_asset_id,
|
||||||
list_references_by_asset_id,
|
list_references_by_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
set_reference_preview,
|
set_reference_preview,
|
||||||
set_reference_tags,
|
set_reference_tags,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
update_reference_access_time,
|
update_reference_access_time,
|
||||||
update_reference_name,
|
update_reference_name,
|
||||||
update_reference_updated_at,
|
update_reference_updated_at,
|
||||||
@ -67,6 +69,8 @@ def update_asset_metadata(
|
|||||||
user_metadata: UserMetadata = None,
|
user_metadata: UserMetadata = None,
|
||||||
tag_origin: str = "manual",
|
tag_origin: str = "manual",
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> AssetDetailResult:
|
) -> AssetDetailResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||||
@ -103,6 +107,21 @@ def update_asset_metadata(
|
|||||||
)
|
)
|
||||||
touched = True
|
touched = True
|
||||||
|
|
||||||
|
if mime_type is not None:
|
||||||
|
updated = update_asset_hash_and_mime(
|
||||||
|
session, asset_id=ref.asset_id, mime_type=mime_type
|
||||||
|
)
|
||||||
|
if updated:
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if preview_id is not None:
|
||||||
|
set_reference_preview(
|
||||||
|
session,
|
||||||
|
reference_id=reference_id,
|
||||||
|
preview_reference_id=preview_id,
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
if touched and user_metadata is None:
|
if touched and user_metadata is None:
|
||||||
update_reference_updated_at(session, reference_id=reference_id)
|
update_reference_updated_at(session, reference_id=reference_id)
|
||||||
|
|
||||||
@ -159,11 +178,9 @@ def delete_asset_reference(
|
|||||||
session.commit()
|
session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Orphaned asset - delete it and its files
|
# Orphaned asset - gather ALL file paths (including
|
||||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
# soft-deleted / missing refs) so their on-disk files get cleaned up.
|
||||||
file_paths = [
|
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
|
||||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
|
||||||
]
|
|
||||||
# Also include the just-deleted file path
|
# Also include the just-deleted file path
|
||||||
if file_path:
|
if file_path:
|
||||||
file_paths.append(file_path)
|
file_paths.append(file_path)
|
||||||
@ -185,7 +202,7 @@ def delete_asset_reference(
|
|||||||
|
|
||||||
def set_asset_preview(
|
def set_asset_preview(
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
preview_asset_id: str | None = None,
|
preview_reference_id: str | None = None,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
) -> AssetDetailResult:
|
) -> AssetDetailResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
@ -194,7 +211,7 @@ def set_asset_preview(
|
|||||||
set_reference_preview(
|
set_reference_preview(
|
||||||
session,
|
session,
|
||||||
reference_id=reference_id,
|
reference_id=reference_id,
|
||||||
preview_asset_id=preview_asset_id,
|
preview_reference_id=preview_reference_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = fetch_reference_asset_and_tags(
|
result = fetch_reference_asset_and_tags(
|
||||||
@ -263,6 +280,47 @@ def list_assets_page(
|
|||||||
return ListAssetsResult(items=items, total=total)
|
return ListAssetsResult(items=items, total=total)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hash_to_path(
|
||||||
|
asset_hash: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> DownloadResolutionResult | None:
|
||||||
|
"""Resolve a blake3 hash to an on-disk file path.
|
||||||
|
|
||||||
|
Only references visible to *owner_id* are considered (owner-less
|
||||||
|
references are always visible).
|
||||||
|
|
||||||
|
Returns a DownloadResolutionResult with abs_path, content_type, and
|
||||||
|
download_name, or None if no asset or live path is found.
|
||||||
|
"""
|
||||||
|
with create_session() as session:
|
||||||
|
asset = queries_get_asset_by_hash(session, asset_hash)
|
||||||
|
if not asset:
|
||||||
|
return None
|
||||||
|
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||||
|
visible = [
|
||||||
|
r for r in refs
|
||||||
|
if r.owner_id == "" or r.owner_id == owner_id
|
||||||
|
]
|
||||||
|
abs_path = select_best_live_path(visible)
|
||||||
|
if not abs_path:
|
||||||
|
return None
|
||||||
|
display_name = os.path.basename(abs_path)
|
||||||
|
for ref in visible:
|
||||||
|
if ref.file_path == abs_path and ref.name:
|
||||||
|
display_name = ref.name
|
||||||
|
break
|
||||||
|
ctype = (
|
||||||
|
asset.mime_type
|
||||||
|
or mimetypes.guess_type(display_name)[0]
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
return DownloadResolutionResult(
|
||||||
|
abs_path=abs_path,
|
||||||
|
content_type=ctype,
|
||||||
|
download_name=display_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_asset_for_download(
|
def resolve_asset_for_download(
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
|||||||
@ -11,13 +11,14 @@ from app.assets.database.queries import (
|
|||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
get_reference_tags,
|
get_reference_tags,
|
||||||
get_or_create_reference,
|
get_or_create_reference,
|
||||||
|
reference_exists,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
set_reference_tags,
|
set_reference_tags,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
upsert_asset,
|
upsert_asset,
|
||||||
upsert_reference,
|
upsert_reference,
|
||||||
validate_tags_exist,
|
validate_tags_exist,
|
||||||
@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags
|
|||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
|
get_name_and_tags_from_asset_path,
|
||||||
resolve_destination_from_tags,
|
resolve_destination_from_tags,
|
||||||
validate_path_within_base,
|
validate_path_within_base,
|
||||||
)
|
)
|
||||||
@ -65,7 +67,7 @@ def _ingest_file_from_path(
|
|||||||
|
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
if preview_id:
|
if preview_id:
|
||||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
if not reference_exists(session, preview_id):
|
||||||
preview_id = None
|
preview_id = None
|
||||||
|
|
||||||
asset, asset_created, asset_updated = upsert_asset(
|
asset, asset_created, asset_updated = upsert_asset(
|
||||||
@ -135,6 +137,8 @@ def _register_existing_asset(
|
|||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
tag_origin: str = "manual",
|
tag_origin: str = "manual",
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> RegisterAssetResult:
|
) -> RegisterAssetResult:
|
||||||
user_metadata = user_metadata or {}
|
user_metadata = user_metadata or {}
|
||||||
|
|
||||||
@ -143,14 +147,25 @@ def _register_existing_asset(
|
|||||||
if not asset:
|
if not asset:
|
||||||
raise ValueError(f"No asset with hash {asset_hash}")
|
raise ValueError(f"No asset with hash {asset_hash}")
|
||||||
|
|
||||||
|
if mime_type and not asset.mime_type:
|
||||||
|
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
|
||||||
|
|
||||||
|
if preview_id:
|
||||||
|
if not reference_exists(session, preview_id):
|
||||||
|
preview_id = None
|
||||||
|
|
||||||
ref, ref_created = get_or_create_reference(
|
ref, ref_created = get_or_create_reference(
|
||||||
session,
|
session,
|
||||||
asset_id=asset.id,
|
asset_id=asset.id,
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
name=name,
|
name=name,
|
||||||
|
preview_id=preview_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not ref_created:
|
if not ref_created:
|
||||||
|
if preview_id and ref.preview_id != preview_id:
|
||||||
|
ref.preview_id = preview_id
|
||||||
|
|
||||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||||
result = RegisterAssetResult(
|
result = RegisterAssetResult(
|
||||||
ref=extract_reference_data(ref),
|
ref=extract_reference_data(ref),
|
||||||
@ -242,6 +257,8 @@ def upload_from_temp_path(
|
|||||||
client_filename: str | None = None,
|
client_filename: str | None = None,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
expected_hash: str | None = None,
|
expected_hash: str | None = None,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> UploadResult:
|
) -> UploadResult:
|
||||||
try:
|
try:
|
||||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||||
@ -270,6 +287,8 @@ def upload_from_temp_path(
|
|||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
tag_origin="manual",
|
tag_origin="manual",
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
|
mime_type=mime_type,
|
||||||
|
preview_id=preview_id,
|
||||||
)
|
)
|
||||||
return UploadResult(
|
return UploadResult(
|
||||||
ref=result.ref,
|
ref=result.ref,
|
||||||
@ -291,7 +310,7 @@ def upload_from_temp_path(
|
|||||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||||
validate_path_within_base(dest_abs, base_dir)
|
validate_path_within_base(dest_abs, base_dir)
|
||||||
|
|
||||||
content_type = (
|
content_type = mime_type or (
|
||||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||||
or "application/octet-stream"
|
or "application/octet-stream"
|
||||||
@ -315,7 +334,7 @@ def upload_from_temp_path(
|
|||||||
mime_type=content_type,
|
mime_type=content_type,
|
||||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
preview_id=None,
|
preview_id=preview_id,
|
||||||
user_metadata=user_metadata or {},
|
user_metadata=user_metadata or {},
|
||||||
tags=tags,
|
tags=tags,
|
||||||
tag_origin="manual",
|
tag_origin="manual",
|
||||||
@ -342,30 +361,99 @@ def upload_from_temp_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_file_in_place(
|
||||||
|
abs_path: str,
|
||||||
|
name: str,
|
||||||
|
tags: list[str],
|
||||||
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> UploadResult:
|
||||||
|
"""Register an already-saved file in the asset database without moving it.
|
||||||
|
|
||||||
|
Tags are derived from the filesystem path (root category + subfolder names),
|
||||||
|
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||||
|
If the path is not under a known root, only the caller-provided tags are used.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||||
|
except ValueError:
|
||||||
|
path_tags = []
|
||||||
|
merged_tags = normalize_tags([*path_tags, *tags])
|
||||||
|
|
||||||
|
try:
|
||||||
|
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise DependencyMissingError(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to hash file: {e}")
|
||||||
|
asset_hash = "blake3:" + digest
|
||||||
|
|
||||||
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
|
content_type = mime_type or (
|
||||||
|
mimetypes.guess_type(abs_path, strict=False)[0]
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
ingest_result = _ingest_file_from_path(
|
||||||
|
abs_path=abs_path,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mtime_ns=mtime_ns,
|
||||||
|
mime_type=content_type,
|
||||||
|
info_name=_sanitize_filename(name, fallback=digest),
|
||||||
|
owner_id=owner_id,
|
||||||
|
tags=merged_tags,
|
||||||
|
tag_origin="upload",
|
||||||
|
require_existing_tags=False,
|
||||||
|
)
|
||||||
|
reference_id = ingest_result.reference_id
|
||||||
|
if not reference_id:
|
||||||
|
raise RuntimeError("failed to create asset reference")
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
pair = fetch_reference_and_asset(
|
||||||
|
session, reference_id=reference_id, owner_id=owner_id
|
||||||
|
)
|
||||||
|
if not pair:
|
||||||
|
raise RuntimeError("inconsistent DB state after ingest")
|
||||||
|
ref, asset = pair
|
||||||
|
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||||
|
|
||||||
|
return UploadResult(
|
||||||
|
ref=extract_reference_data(ref),
|
||||||
|
asset=extract_asset_data(asset),
|
||||||
|
tags=tag_names,
|
||||||
|
created_new=ingest_result.asset_created,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_from_hash(
|
def create_from_hash(
|
||||||
hash_str: str,
|
hash_str: str,
|
||||||
name: str,
|
name: str,
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
user_metadata: dict | None = None,
|
user_metadata: dict | None = None,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> UploadResult | None:
|
) -> UploadResult | None:
|
||||||
canonical = hash_str.strip().lower()
|
canonical = hash_str.strip().lower()
|
||||||
|
|
||||||
with create_session() as session:
|
try:
|
||||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
result = _register_existing_asset(
|
||||||
if not asset:
|
asset_hash=canonical,
|
||||||
return None
|
name=_sanitize_filename(
|
||||||
|
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||||
result = _register_existing_asset(
|
),
|
||||||
asset_hash=canonical,
|
user_metadata=user_metadata or {},
|
||||||
name=_sanitize_filename(
|
tags=tags or [],
|
||||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
tag_origin="manual",
|
||||||
),
|
owner_id=owner_id,
|
||||||
user_metadata=user_metadata or {},
|
mime_type=mime_type,
|
||||||
tags=tags or [],
|
preview_id=preview_id,
|
||||||
tag_origin="manual",
|
)
|
||||||
owner_id=owner_id,
|
except ValueError:
|
||||||
)
|
logging.warning("create_from_hash: no asset found for hash %s", canonical)
|
||||||
|
return None
|
||||||
|
|
||||||
return UploadResult(
|
return UploadResult(
|
||||||
ref=result.ref,
|
ref=result.ref,
|
||||||
|
|||||||
@ -25,7 +25,9 @@ class ReferenceData:
|
|||||||
preview_id: str | None
|
preview_id: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
last_access_time: datetime | None
|
system_metadata: dict[str, Any] | None = None
|
||||||
|
job_id: str | None = None
|
||||||
|
last_access_time: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
|||||||
file_path=ref.file_path,
|
file_path=ref.file_path,
|
||||||
user_metadata=ref.user_metadata,
|
user_metadata=ref.user_metadata,
|
||||||
preview_id=ref.preview_id,
|
preview_id=ref.preview_id,
|
||||||
|
system_metadata=ref.system_metadata,
|
||||||
|
job_id=ref.job_id,
|
||||||
created_at=ref.created_at,
|
created_at=ref.created_at,
|
||||||
updated_at=ref.updated_at,
|
updated_at=ref.updated_at,
|
||||||
last_access_time=ref.last_access_time,
|
last_access_time=ref.last_access_time,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
AddTagsResult,
|
AddTagsResult,
|
||||||
RemoveTagsResult,
|
RemoveTagsResult,
|
||||||
@ -6,6 +8,7 @@ from app.assets.database.queries import (
|
|||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
remove_tags_from_reference,
|
remove_tags_from_reference,
|
||||||
)
|
)
|
||||||
|
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
|
||||||
from app.assets.services.schemas import TagUsage
|
from app.assets.services.schemas import TagUsage
|
||||||
from app.database.db import create_session
|
from app.database.db import create_session
|
||||||
|
|
||||||
@ -73,3 +76,23 @@ def list_tags(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||||
|
|
||||||
|
|
||||||
|
def list_tag_histogram(
|
||||||
|
owner_id: str = "",
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
name_contains: str | None = None,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
with create_session() as session:
|
||||||
|
return list_tag_counts_for_filtered_assets(
|
||||||
|
session,
|
||||||
|
owner_id=owner_id,
|
||||||
|
include_tags=include_tags,
|
||||||
|
exclude_tags=exclude_tags,
|
||||||
|
name_contains=name_contains,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|||||||
@ -1,9 +1,18 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy import MetaData
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
NAMING_CONVENTION = {
|
||||||
|
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
|
||||||
|
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
|
||||||
|
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||||
|
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||||
|
"pk": "pk_%(table_name)s",
|
||||||
|
}
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
pass
|
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||||
|
|
||||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||||
fields = obj.__table__.columns.keys()
|
fields = obj.__table__.columns.keys()
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import uuid
|
|||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
import logging
|
import logging
|
||||||
|
import tempfile
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
@ -377,8 +378,15 @@ class UserManager():
|
|||||||
try:
|
try:
|
||||||
body = await request.read()
|
body = await request.read()
|
||||||
|
|
||||||
with open(path, "wb") as f:
|
dir_name = os.path.dirname(path)
|
||||||
f.write(body)
|
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||||
|
try:
|
||||||
|
with os.fdopen(fd, "wb") as f:
|
||||||
|
f.write(body)
|
||||||
|
os.replace(tmp_path, path)
|
||||||
|
except:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
raise
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logging.warning(f"Error saving file '{path}': {e}")
|
logging.warning(f"Error saving file '{path}': {e}")
|
||||||
return web.Response(
|
return web.Response(
|
||||||
|
|||||||
@ -149,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
|||||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||||
|
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
|
|
||||||
@ -262,4 +263,6 @@ else:
|
|||||||
args.fast = set(args.fast)
|
args.fast = set(args.fast)
|
||||||
|
|
||||||
def enables_dynamic_vram():
|
def enables_dynamic_vram():
|
||||||
|
if args.enable_dynamic_vram:
|
||||||
|
return True
|
||||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||||
|
|||||||
@ -136,16 +136,7 @@ class ResBlock(nn.Module):
|
|||||||
ops.Linear(c_hidden, c),
|
ops.Linear(c_hidden, c),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||||
|
|
||||||
# Init weights
|
|
||||||
def _basic_init(module):
|
|
||||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.constant_(module.bias, 0)
|
|
||||||
|
|
||||||
self.apply(_basic_init)
|
|
||||||
|
|
||||||
def _norm(self, x, norm):
|
def _norm(self, x, norm):
|
||||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
|||||||
@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
|
|||||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||||
v,
|
v,
|
||||||
heads=self.num_heads,
|
heads=self.num_heads,
|
||||||
|
low_precision_attention=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = self.out_proj(x)
|
out = self.out_proj(x)
|
||||||
@ -412,6 +413,7 @@ class Attention(nn.Module):
|
|||||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||||
value,
|
value,
|
||||||
heads=self.num_heads,
|
heads=self.num_heads,
|
||||||
|
low_precision_attention=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
|
|||||||
@ -65,9 +65,13 @@ class CausalConv3d(nn.Module):
|
|||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||||
|
|
||||||
x = torch.cat(pieces, dim=2)
|
x = torch.cat(pieces, dim=2)
|
||||||
|
del pieces
|
||||||
|
del cached
|
||||||
|
|
||||||
if needs_caching:
|
if needs_caching:
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
||||||
|
elif is_end:
|
||||||
|
self.temporal_cache_state[tid] = (None, True)
|
||||||
|
|
||||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||||
|
|
||||||
|
|||||||
@ -297,7 +297,23 @@ class Encoder(nn.Module):
|
|||||||
module.temporal_cache_state.pop(tid, None)
|
module.temporal_cache_state.pop(tid, None)
|
||||||
|
|
||||||
|
|
||||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||||
|
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||||
|
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||||
|
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||||
|
|
||||||
|
def get_max_chunk_size(device: torch.device) -> int:
|
||||||
|
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||||
|
|
||||||
|
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||||
|
return MIN_CHUNK_SIZE
|
||||||
|
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||||
|
return MAX_CHUNK_SIZE
|
||||||
|
|
||||||
|
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||||
|
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||||
|
)
|
||||||
|
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
@ -525,8 +541,11 @@ class Decoder(nn.Module):
|
|||||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
output = []
|
output = []
|
||||||
|
max_chunk_size = get_max_chunk_size(sample.device)
|
||||||
|
|
||||||
def run_up(idx, sample, ended):
|
def run_up(idx, sample_ref, ended):
|
||||||
|
sample = sample_ref[0]
|
||||||
|
sample_ref[0] = None
|
||||||
if idx >= len(self.up_blocks):
|
if idx >= len(self.up_blocks):
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
if timestep_shift_scale is not None:
|
if timestep_shift_scale is not None:
|
||||||
@ -554,13 +573,21 @@ class Decoder(nn.Module):
|
|||||||
return
|
return
|
||||||
|
|
||||||
total_bytes = sample.numel() * sample.element_size()
|
total_bytes = sample.numel() * sample.element_size()
|
||||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
|
||||||
|
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
if num_chunks == 1:
|
||||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||||
|
next_sample_ref = [sample]
|
||||||
|
del sample
|
||||||
|
run_up(idx + 1, next_sample_ref, ended)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||||
|
|
||||||
run_up(0, sample, True)
|
for chunk_idx, sample1 in enumerate(samples):
|
||||||
|
run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1)
|
||||||
|
|
||||||
|
run_up(0, [sample], True)
|
||||||
sample = torch.cat(output, dim=2)
|
sample = torch.cat(output, dim=2)
|
||||||
|
|
||||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ class Resample(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.resample = nn.Identity()
|
self.resample = nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
b, c, t, h, w = x.size()
|
b, c, t, h, w = x.size()
|
||||||
if self.mode == 'upsample3d':
|
if self.mode == 'upsample3d':
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
@ -109,22 +109,7 @@ class Resample(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] != 'Rep':
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] == 'Rep':
|
|
||||||
cache_x = torch.cat([
|
|
||||||
torch.zeros_like(cache_x).to(cache_x.device),
|
|
||||||
cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if feat_cache[idx] == 'Rep':
|
if feat_cache[idx] == 'Rep':
|
||||||
x = self.time_conv(x)
|
x = self.time_conv(x)
|
||||||
else:
|
else:
|
||||||
@ -145,19 +130,24 @@ class Resample(nn.Module):
|
|||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
if feat_cache[idx] is None:
|
if feat_cache[idx] is None:
|
||||||
feat_cache[idx] = x.clone()
|
feat_cache[idx] = x
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -1:, :, :].clone()
|
cache_x = x[:, :, -1:, :, :]
|
||||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
|
||||||
# # cache last frame of last two chunk
|
|
||||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
|
||||||
|
|
||||||
x = self.time_conv(
|
x = self.time_conv(
|
||||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
|
||||||
|
deferred_x = feat_cache[idx + 1]
|
||||||
|
if deferred_x is not None:
|
||||||
|
x = torch.cat([deferred_x, x], 2)
|
||||||
|
feat_cache[idx + 1] = None
|
||||||
|
|
||||||
|
if x.shape[2] == 1 and not final:
|
||||||
|
feat_cache[idx + 1] = x
|
||||||
|
x = None
|
||||||
|
|
||||||
|
feat_idx[0] += 2
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
|
|||||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||||
if in_dim != out_dim else nn.Identity()
|
if in_dim != out_dim else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
old_x = x
|
old_x = x
|
||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
|
|||||||
self.proj = ops.Conv2d(dim, dim, 1)
|
self.proj = ops.Conv2d(dim, dim, 1)
|
||||||
self.optimized_attention = vae_attention()
|
self.optimized_attention = vae_attention()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
identity = x
|
identity = x
|
||||||
b, c, t, h, w = x.size()
|
b, c, t, h, w = x.size()
|
||||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||||
@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
|
|||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
|
|||||||
## downsamples
|
## downsamples
|
||||||
for layer in self.downsamples:
|
for layer in self.downsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx, final=final)
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx, final=final)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
|
|||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -393,14 +364,7 @@ class Decoder3d(nn.Module):
|
|||||||
## conv1
|
## conv1
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -409,42 +373,56 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
## upsamples
|
|
||||||
for layer in self.upsamples:
|
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## head
|
out_chunks = []
|
||||||
for layer in self.head:
|
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
def run_up(layer_idx, x_ref, feat_idx):
|
||||||
idx = feat_idx[0]
|
x = x_ref[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
x_ref[0] = None
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
if layer_idx >= len(self.upsamples):
|
||||||
# cache last frame of last two chunk
|
for layer in self.head:
|
||||||
cache_x = torch.cat([
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
cache_x.device), cache_x
|
x = layer(x, feat_cache[feat_idx[0]])
|
||||||
],
|
feat_cache[feat_idx[0]] = cache_x
|
||||||
dim=2)
|
feat_idx[0] += 1
|
||||||
x = layer(x, feat_cache[idx])
|
else:
|
||||||
feat_cache[idx] = cache_x
|
x = layer(x)
|
||||||
feat_idx[0] += 1
|
out_chunks.append(x)
|
||||||
|
return
|
||||||
|
|
||||||
|
layer = self.upsamples[layer_idx]
|
||||||
|
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 1:
|
||||||
|
for frame_idx in range(x.shape[2]):
|
||||||
|
run_up(
|
||||||
|
layer_idx,
|
||||||
|
[x[:, :, frame_idx:frame_idx + 1, :, :]],
|
||||||
|
feat_idx.copy(),
|
||||||
|
)
|
||||||
|
del x
|
||||||
|
return
|
||||||
|
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
return x
|
|
||||||
|
next_x_ref = [x]
|
||||||
|
del x
|
||||||
|
run_up(layer_idx + 1, next_x_ref, feat_idx)
|
||||||
|
|
||||||
|
run_up(0, [x], feat_idx)
|
||||||
|
return out_chunks
|
||||||
|
|
||||||
|
|
||||||
def count_conv3d(model):
|
def count_cache_layers(model):
|
||||||
count = 0
|
count = 0
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, CausalConv3d):
|
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@ -482,11 +460,12 @@ class WanVAE(nn.Module):
|
|||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
## cache
|
## cache
|
||||||
t = x.shape[2]
|
t = x.shape[2]
|
||||||
iter_ = 1 + (t - 1) // 4
|
t = 1 + ((t - 1) // 4) * 4
|
||||||
|
iter_ = 1 + (t - 1) // 2
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.encoder)
|
feat_map = [None] * count_cache_layers(self.encoder)
|
||||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@ -496,20 +475,23 @@ class WanVAE(nn.Module):
|
|||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.encoder(
|
out_ = self.encoder(
|
||||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||||
feat_cache=feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx,
|
||||||
|
final=(i == (iter_ - 1)))
|
||||||
|
if out_ is None:
|
||||||
|
continue
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
|
|
||||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
def decode(self, z):
|
def decode(self, z):
|
||||||
conv_idx = [0]
|
|
||||||
# z: [b,c,t,h,w]
|
# z: [b,c,t,h,w]
|
||||||
iter_ = z.shape[2]
|
iter_ = 1 + z.shape[2] // 2
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.decoder)
|
feat_map = [None] * count_cache_layers(self.decoder)
|
||||||
x = self.conv2(z)
|
x = self.conv2(z)
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
@ -520,8 +502,8 @@ class WanVAE(nn.Module):
|
|||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(
|
out_ = self.decoder(
|
||||||
x[:, :, i:i + 1, :, :],
|
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||||
feat_cache=feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out += out_
|
||||||
return out
|
return torch.cat(out, 2)
|
||||||
|
|||||||
@ -541,6 +541,7 @@ class LoadedModel:
|
|||||||
if model.parent is not None:
|
if model.parent is not None:
|
||||||
self._parent_model = weakref.ref(model.parent)
|
self._parent_model = weakref.ref(model.parent)
|
||||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||||
|
self._patcher_finalizer.atexit = False
|
||||||
|
|
||||||
def _switch_parent(self):
|
def _switch_parent(self):
|
||||||
model = self._parent_model()
|
model = self._parent_model()
|
||||||
@ -587,6 +588,7 @@ class LoadedModel:
|
|||||||
|
|
||||||
self.real_model = weakref.ref(real_model)
|
self.real_model = weakref.ref(real_model)
|
||||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||||
|
self.model_finalizer.atexit = False
|
||||||
return real_model
|
return real_model
|
||||||
|
|
||||||
def should_reload_model(self, force_patch_weights=False):
|
def should_reload_model(self, force_patch_weights=False):
|
||||||
@ -1001,7 +1003,7 @@ def text_encoder_offload_device():
|
|||||||
def text_encoder_device():
|
def text_encoder_device():
|
||||||
if args.gpu_only:
|
if args.gpu_only:
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
|
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
|
||||||
if should_use_fp16(prioritize_performance=False):
|
if should_use_fp16(prioritize_performance=False):
|
||||||
return get_torch_device()
|
return get_torch_device()
|
||||||
else:
|
else:
|
||||||
|
|||||||
101
comfy/ops.py
101
comfy/ops.py
@ -776,6 +776,71 @@ from .quant_ops import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
|
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||||
|
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||||
|
input_shape = input_float.shape
|
||||||
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
|
# Quantize input (same as inference path)
|
||||||
|
if layout_type is not None:
|
||||||
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
|
else:
|
||||||
|
q_input = inp
|
||||||
|
|
||||||
|
w = weight.detach() if weight.requires_grad else weight
|
||||||
|
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||||
|
|
||||||
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
|
# Restore original input shape
|
||||||
|
if len(input_shape) > 2:
|
||||||
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
ctx.input_shape = input_shape
|
||||||
|
ctx.has_bias = bias is not None
|
||||||
|
ctx.compute_dtype = compute_dtype
|
||||||
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_float, weight = ctx.saved_tensors
|
||||||
|
compute_dtype = ctx.compute_dtype
|
||||||
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
|
# Dequantize weight to compute dtype for backward matmul
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_f = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_f = weight.to(compute_dtype)
|
||||||
|
|
||||||
|
# grad_input = grad_output @ weight
|
||||||
|
grad_input = torch.mm(grad_2d, weight_f)
|
||||||
|
if len(ctx.input_shape) > 2:
|
||||||
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
|
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||||
|
grad_weight = None
|
||||||
|
if ctx.weight_requires_grad:
|
||||||
|
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||||
|
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||||
|
|
||||||
|
# grad_bias
|
||||||
|
grad_bias = None
|
||||||
|
if ctx.has_bias:
|
||||||
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
|
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
_quant_config = quant_config
|
_quant_config = quant_config
|
||||||
@ -970,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
#If cast needs to apply lora, it should be done in the compute dtype
|
#If cast needs to apply lora, it should be done in the compute dtype
|
||||||
compute_dtype = input.dtype
|
compute_dtype = input.dtype
|
||||||
|
|
||||||
if (getattr(self, 'layout_type', None) is not None and
|
_use_quantized = (
|
||||||
|
getattr(self, 'layout_type', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||||
|
if (input.requires_grad and _use_quantized):
|
||||||
|
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(
|
||||||
|
self,
|
||||||
|
input,
|
||||||
|
offloadable=True,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
want_requant=True
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = getattr(self, 'input_scale', None)
|
||||||
|
if scale is not None:
|
||||||
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
|
|
||||||
|
output = QuantLinearFunc.apply(
|
||||||
|
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return output
|
||||||
|
|
||||||
|
# Inference path (unchanged)
|
||||||
|
if _use_quantized:
|
||||||
|
|
||||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||||
@ -1021,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
for key, param in self._parameters.items():
|
for key, param in self._parameters.items():
|
||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
p = fn(param)
|
||||||
|
if p.is_inference():
|
||||||
|
p = p.clone()
|
||||||
|
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
for key, buf in self._buffers.items():
|
for key, buf in self._buffers.items():
|
||||||
if buf is not None:
|
if buf is not None:
|
||||||
self._buffers[key] = fn(buf)
|
self._buffers[key] = fn(buf)
|
||||||
|
|||||||
@ -455,7 +455,7 @@ class VAE:
|
|||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.pad_channel_value = None
|
self.pad_channel_value = None
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
self.disable_offload = False
|
self.disable_offload = False
|
||||||
self.not_video = False
|
self.not_video = False
|
||||||
@ -952,8 +952,8 @@ class VAE:
|
|||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True))
|
||||||
if pixel_samples is None:
|
if pixel_samples is None:
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
pixel_samples[x:x+batch_number] = out
|
pixel_samples[x:x+batch_number] = out
|
||||||
|
|||||||
@ -897,6 +897,10 @@ def set_attr(obj, attr, value):
|
|||||||
return prev
|
return prev
|
||||||
|
|
||||||
def set_attr_param(obj, attr, value):
|
def set_attr_param(obj, attr, value):
|
||||||
|
# Clone inference tensors (created under torch.inference_mode) since
|
||||||
|
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||||
|
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||||
|
value = value.clone()
|
||||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||||
|
|
||||||
def set_attr_buffer(obj, attr, value):
|
def set_attr_buffer(obj, attr, value):
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
|||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy_extras.nodes_custom_sampler
|
import comfy_extras.nodes_custom_sampler
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
training_dtype=torch.bfloat16,
|
training_dtype=torch.bfloat16,
|
||||||
real_dataset=None,
|
real_dataset=None,
|
||||||
bucket_latents=None,
|
bucket_latents=None,
|
||||||
|
use_grad_scaler=False,
|
||||||
):
|
):
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self.bucket_latents: list[torch.Tensor] | None = (
|
self.bucket_latents: list[torch.Tensor] | None = (
|
||||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||||
)
|
)
|
||||||
|
# GradScaler for fp16 training
|
||||||
|
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||||
# Precompute bucket offsets and weights for sampling
|
# Precompute bucket offsets and weights for sampling
|
||||||
if bucket_latents is not None:
|
if bucket_latents is not None:
|
||||||
self._init_bucket_data(bucket_latents)
|
self._init_bucket_data(bucket_latents)
|
||||||
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
batch_sigmas.requires_grad_(True),
|
batch_sigmas.requires_grad_(True),
|
||||||
**batch_extra_args,
|
**batch_extra_args,
|
||||||
)
|
)
|
||||||
loss = self.loss_fn(x0_pred, x0)
|
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||||
if bwd:
|
if bwd:
|
||||||
bwd_loss = loss / self.grad_acc
|
bwd_loss = loss / self.grad_acc
|
||||||
bwd_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(bwd_loss).backward()
|
||||||
|
else:
|
||||||
|
bwd_loss.backward()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||||
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
)
|
)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||||
total_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(total_loss).backward()
|
||||||
|
else:
|
||||||
|
total_loss.backward()
|
||||||
if self.loss_callback:
|
if self.loss_callback:
|
||||||
self.loss_callback(total_loss.item())
|
self.loss_callback(total_loss.item())
|
||||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||||
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||||
|
|
||||||
if (i + 1) % self.grad_acc == 0:
|
if (i + 1) % self.grad_acc == 0:
|
||||||
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.unscale_(self.optimizer)
|
||||||
for param_groups in self.optimizer.param_groups:
|
for param_groups in self.optimizer.param_groups:
|
||||||
for param in param_groups["params"]:
|
for param in param_groups["params"]:
|
||||||
if param.grad is None:
|
if param.grad is None:
|
||||||
continue
|
continue
|
||||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||||
self.optimizer.step()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.step(self.optimizer)
|
||||||
|
self.grad_scaler.update()
|
||||||
|
else:
|
||||||
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
ui_pbar.update(1)
|
ui_pbar.update(1)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"training_dtype",
|
"training_dtype",
|
||||||
options=["bf16", "fp32"],
|
options=["bf16", "fp32", "none"],
|
||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for training.",
|
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"lora_dtype",
|
"lora_dtype",
|
||||||
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
io.Boolean.Input(
|
io.Boolean.Input(
|
||||||
"offloading",
|
"offloading",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"existing_lora",
|
"existing_lora",
|
||||||
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
use_grad_scaler = False
|
||||||
|
if training_dtype != "none":
|
||||||
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
else:
|
||||||
|
# Detect model's native dtype for autocast
|
||||||
|
model_dtype = mp.model.get_dtype()
|
||||||
|
if model_dtype == torch.float16:
|
||||||
|
dtype = torch.float16
|
||||||
|
use_grad_scaler = True
|
||||||
|
# Warn about fp16 accumulation instability during training
|
||||||
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
|
logging.warning(
|
||||||
|
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||||
|
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||||
|
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
|
dtype = torch.bfloat16
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
|
||||||
|
|
||||||
if mp.is_dynamic():
|
|
||||||
if not bypass_mode:
|
|
||||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
|
||||||
bypass_mode = True
|
|
||||||
offloading = True
|
|
||||||
elif offloading:
|
|
||||||
if not bypass_mode:
|
|
||||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||||
latents, dtype, bucket_mode
|
latents, latents_dtype, bucket_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate and expand conditioning
|
# Validate and expand conditioning
|
||||||
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
bucket_latents=latents,
|
bucket_latents=latents,
|
||||||
|
use_grad_scaler=use_grad_scaler,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = TrainSampler(
|
train_sampler = TrainSampler(
|
||||||
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
real_dataset=latents if multi_res else None,
|
real_dataset=latents if multi_res else None,
|
||||||
|
use_grad_scaler=use_grad_scaler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
|
|||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
|
|||||||
4
main.py
4
main.py
@ -206,8 +206,8 @@ import hook_breaker_ac10a0
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
|
||||||
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||||
if args.verbose == 'DEBUG':
|
if args.verbose == 'DEBUG':
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b5
|
comfyui_manager==4.1b6
|
||||||
2
nodes.py
2
nodes.py
@ -952,7 +952,7 @@ class UNETLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.41.20
|
comfyui-frontend-package==1.41.20
|
||||||
comfyui-workflow-templates==0.9.21
|
comfyui-workflow-templates==0.9.26
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
|
|||||||
79
server.py
79
server.py
@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
|
|||||||
from comfy_api.internal import _ComfyNodeInternal
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
from app.assets.seeder import asset_seeder
|
from app.assets.seeder import asset_seeder
|
||||||
from app.assets.api.routes import register_assets_routes
|
from app.assets.api.routes import register_assets_routes
|
||||||
|
from app.assets.services.ingest import register_file_in_place
|
||||||
|
from app.assets.services.asset_management import resolve_hash_to_path
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
@ -419,7 +421,24 @@ class PromptServer():
|
|||||||
with open(filepath, "wb") as f:
|
with open(filepath, "wb") as f:
|
||||||
f.write(image.file.read())
|
f.write(image.file.read())
|
||||||
|
|
||||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
|
||||||
|
|
||||||
|
if args.enable_assets:
|
||||||
|
try:
|
||||||
|
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||||
|
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||||
|
resp["asset"] = {
|
||||||
|
"id": result.ref.id,
|
||||||
|
"name": result.ref.name,
|
||||||
|
"asset_hash": result.asset.hash,
|
||||||
|
"size": result.asset.size_bytes,
|
||||||
|
"mime_type": result.asset.mime_type,
|
||||||
|
"tags": result.tags,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logging.warning("Failed to register uploaded image as asset", exc_info=True)
|
||||||
|
|
||||||
|
return web.json_response(resp)
|
||||||
else:
|
else:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
@ -479,30 +498,43 @@ class PromptServer():
|
|||||||
async def view_image(request):
|
async def view_image(request):
|
||||||
if "filename" in request.rel_url.query:
|
if "filename" in request.rel_url.query:
|
||||||
filename = request.rel_url.query["filename"]
|
filename = request.rel_url.query["filename"]
|
||||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
|
||||||
|
|
||||||
if not filename:
|
# The frontend's LoadImage combo widget uses asset_hash values
|
||||||
return web.Response(status=400)
|
# (e.g. "blake3:...") as widget values. When litegraph renders the
|
||||||
|
# node preview, it constructs /view?filename=<asset_hash>, so this
|
||||||
|
# endpoint must resolve blake3 hashes to their on-disk file paths.
|
||||||
|
if filename.startswith("blake3:"):
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
result = resolve_hash_to_path(filename, owner_id=owner_id)
|
||||||
|
if result is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
|
||||||
|
else:
|
||||||
|
resolved_content_type = None
|
||||||
|
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||||
|
|
||||||
# validation for security: prevent accessing arbitrary path
|
if not filename:
|
||||||
if filename[0] == '/' or '..' in filename:
|
return web.Response(status=400)
|
||||||
return web.Response(status=400)
|
|
||||||
|
|
||||||
if output_dir is None:
|
# validation for security: prevent accessing arbitrary path
|
||||||
type = request.rel_url.query.get("type", "output")
|
if filename[0] == '/' or '..' in filename:
|
||||||
output_dir = folder_paths.get_directory_by_type(type)
|
return web.Response(status=400)
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
return web.Response(status=400)
|
type = request.rel_url.query.get("type", "output")
|
||||||
|
output_dir = folder_paths.get_directory_by_type(type)
|
||||||
|
|
||||||
if "subfolder" in request.rel_url.query:
|
if output_dir is None:
|
||||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
return web.Response(status=400)
|
||||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
|
||||||
return web.Response(status=403)
|
|
||||||
output_dir = full_output_dir
|
|
||||||
|
|
||||||
filename = os.path.basename(filename)
|
if "subfolder" in request.rel_url.query:
|
||||||
file = os.path.join(output_dir, filename)
|
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||||
|
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||||
|
return web.Response(status=403)
|
||||||
|
output_dir = full_output_dir
|
||||||
|
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
file = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
if 'preview' in request.rel_url.query:
|
if 'preview' in request.rel_url.query:
|
||||||
@ -562,8 +594,13 @@ class PromptServer():
|
|||||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||||
else:
|
else:
|
||||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
# Use the content type from asset resolution if available,
|
||||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
# otherwise guess from the filename.
|
||||||
|
content_type = (
|
||||||
|
resolved_content_type
|
||||||
|
or mimetypes.guess_type(filename)[0]
|
||||||
|
or 'application/octet-stream'
|
||||||
|
)
|
||||||
|
|
||||||
# For security, force certain mimetypes to download instead of display
|
# For security, force certain mimetypes to download instead of display
|
||||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||||
|
|||||||
57
tests-unit/app_test/test_migrations.py
Normal file
57
tests-unit/app_test/test_migrations.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
|
||||||
|
|
||||||
|
This catches problems like unnamed FK constraints that prevent batch-mode
|
||||||
|
drop_constraint from working on real SQLite files (see MB-2).
|
||||||
|
|
||||||
|
Migrations 0001 and 0002 are already shipped, so we only exercise
|
||||||
|
upgrade/downgrade for 0003+.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from alembic import command
|
||||||
|
from alembic.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
# Oldest shipped revision — we upgrade to here as a baseline and never
|
||||||
|
# downgrade past it.
|
||||||
|
_BASELINE = "0002_merge_to_asset_references"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(db_path: str) -> Config:
|
||||||
|
root = os.path.join(os.path.dirname(__file__), "../..")
|
||||||
|
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
|
||||||
|
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
|
||||||
|
|
||||||
|
cfg = Config(config_path)
|
||||||
|
cfg.set_main_option("script_location", scripts_path)
|
||||||
|
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def migration_db(tmp_path):
|
||||||
|
"""Yield an alembic Config pre-upgraded to the baseline revision."""
|
||||||
|
db_path = str(tmp_path / "test_migration.db")
|
||||||
|
cfg = _make_config(db_path)
|
||||||
|
command.upgrade(cfg, _BASELINE)
|
||||||
|
yield cfg
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade_to_head(migration_db):
|
||||||
|
"""Upgrade from baseline to head must succeed on a file-backed DB."""
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
|
|
||||||
|
|
||||||
|
def test_downgrade_to_baseline(migration_db):
|
||||||
|
"""Upgrade to head then downgrade back to baseline."""
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
|
command.downgrade(migration_db, _BASELINE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade_downgrade_cycle(migration_db):
|
||||||
|
"""Full cycle: upgrade → downgrade → upgrade again."""
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
|
command.downgrade(migration_db, _BASELINE)
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
@ -10,6 +10,7 @@ from app.assets.database.queries import (
|
|||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
upsert_asset,
|
upsert_asset,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -142,3 +143,45 @@ class TestBulkInsertAssets:
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert session.query(Asset).count() == 200
|
assert session.query(Asset).count() == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestMimeTypeImmutability:
|
||||||
|
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"initial_mime,second_mime,expected_mime",
|
||||||
|
[
|
||||||
|
("image/png", "image/jpeg", "image/png"),
|
||||||
|
(None, "image/png", "image/png"),
|
||||||
|
],
|
||||||
|
ids=["preserves_existing", "fills_null"],
|
||||||
|
)
|
||||||
|
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
|
||||||
|
h = f"blake3:upsert_{initial_mime}_{second_mime}"
|
||||||
|
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
|
||||||
|
assert created is False
|
||||||
|
assert asset.mime_type == expected_mime
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
|
||||||
|
[
|
||||||
|
(None, "image/png", None, "image/png", "blake3:upd0"),
|
||||||
|
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
|
||||||
|
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
|
||||||
|
],
|
||||||
|
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
|
||||||
|
)
|
||||||
|
def test_update_asset_hash_and_mime_immutability(
|
||||||
|
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
|
||||||
|
):
|
||||||
|
h = expected_hash.removesuffix("_new")
|
||||||
|
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
|
||||||
|
assert asset.mime_type == expected_mime
|
||||||
|
assert asset.hash == expected_hash
|
||||||
|
|||||||
@ -242,22 +242,24 @@ class TestSetReferencePreview:
|
|||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
preview_asset = _make_asset(session, "preview_hash")
|
preview_asset = _make_asset(session, "preview_hash")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
|
set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
session.refresh(ref)
|
session.refresh(ref)
|
||||||
assert ref.preview_id == preview_asset.id
|
assert ref.preview_id == preview_ref.id
|
||||||
|
|
||||||
def test_clears_preview(self, session: Session):
|
def test_clears_preview(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
preview_asset = _make_asset(session, "preview_hash")
|
preview_asset = _make_asset(session, "preview_hash")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
ref.preview_id = preview_asset.id
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
|
ref.preview_id = preview_ref.id
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
|
set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
session.refresh(ref)
|
session.refresh(ref)
|
||||||
@ -265,15 +267,15 @@ class TestSetReferencePreview:
|
|||||||
|
|
||||||
def test_raises_for_nonexistent_reference(self, session: Session):
|
def test_raises_for_nonexistent_reference(self, session: Session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)
|
set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
|
||||||
|
|
||||||
def test_raises_for_nonexistent_preview(self, session: Session):
|
def test_raises_for_nonexistent_preview(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Preview Asset"):
|
with pytest.raises(ValueError, match="Preview AssetReference"):
|
||||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")
|
set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
|
||||||
|
|
||||||
|
|
||||||
class TestInsertReference:
|
class TestInsertReference:
|
||||||
@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
|
|||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
preview_asset = _make_asset(session, "preview_hash")
|
preview_asset = _make_asset(session, "preview_hash")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
update_reference_timestamps(session, ref, preview_id=preview_asset.id)
|
update_reference_timestamps(session, ref, preview_id=preview_ref.id)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
session.refresh(ref)
|
session.refresh(ref)
|
||||||
assert ref.preview_id == preview_asset.id
|
assert ref.preview_id == preview_ref.id
|
||||||
|
|
||||||
|
|
||||||
class TestSetReferenceMetadata:
|
class TestSetReferenceMetadata:
|
||||||
|
|||||||
@ -20,6 +20,7 @@ def _make_reference(
|
|||||||
asset: Asset,
|
asset: Asset,
|
||||||
name: str,
|
name: str,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
|
system_metadata: dict | None = None,
|
||||||
) -> AssetReference:
|
) -> AssetReference:
|
||||||
now = get_utc_now()
|
now = get_utc_now()
|
||||||
ref = AssetReference(
|
ref = AssetReference(
|
||||||
@ -27,6 +28,7 @@ def _make_reference(
|
|||||||
name=name,
|
name=name,
|
||||||
asset_id=asset.id,
|
asset_id=asset.id,
|
||||||
user_metadata=metadata,
|
user_metadata=metadata,
|
||||||
|
system_metadata=system_metadata,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
last_access_time=now,
|
last_access_time=now,
|
||||||
@ -34,8 +36,10 @@ def _make_reference(
|
|||||||
session.add(ref)
|
session.add(ref)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
if metadata:
|
# Build merged projection: {**system_metadata, **user_metadata}
|
||||||
for key, val in metadata.items():
|
merged = {**(system_metadata or {}), **(metadata or {})}
|
||||||
|
if merged:
|
||||||
|
for key, val in merged.items():
|
||||||
for row in convert_metadata_to_rows(key, val):
|
for row in convert_metadata_to_rows(key, val):
|
||||||
meta_row = AssetReferenceMeta(
|
meta_row = AssetReferenceMeta(
|
||||||
asset_reference_id=ref.id,
|
asset_reference_id=ref.id,
|
||||||
@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
|
|||||||
|
|
||||||
refs, _, total = list_references_page(session, metadata_filter={})
|
refs, _, total = list_references_page(session, metadata_filter={})
|
||||||
assert total == 2
|
assert total == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemMetadataProjection:
|
||||||
|
"""Tests for system_metadata merging into the filter projection."""
|
||||||
|
|
||||||
|
def test_system_metadata_keys_are_filterable(self, session: Session):
|
||||||
|
"""system_metadata keys should appear in the merged projection."""
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
_make_reference(
|
||||||
|
session, asset, "with_sys",
|
||||||
|
system_metadata={"source": "scanner"},
|
||||||
|
)
|
||||||
|
_make_reference(session, asset, "without_sys")
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, metadata_filter={"source": "scanner"}
|
||||||
|
)
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].name == "with_sys"
|
||||||
|
|
||||||
|
def test_user_metadata_overrides_system_metadata(self, session: Session):
|
||||||
|
"""user_metadata should win when both have the same key."""
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
_make_reference(
|
||||||
|
session, asset, "overridden",
|
||||||
|
metadata={"origin": "user_upload"},
|
||||||
|
system_metadata={"origin": "auto_scan"},
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Should match the user value, not the system value
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, metadata_filter={"origin": "user_upload"}
|
||||||
|
)
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].name == "overridden"
|
||||||
|
|
||||||
|
# Should NOT match the system value (it was overridden)
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, metadata_filter={"origin": "auto_scan"}
|
||||||
|
)
|
||||||
|
assert total == 0
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from app.assets.services import (
|
|||||||
delete_asset_reference,
|
delete_asset_reference,
|
||||||
set_asset_preview,
|
set_asset_preview,
|
||||||
)
|
)
|
||||||
|
from app.assets.services.asset_management import resolve_hash_to_path
|
||||||
|
|
||||||
|
|
||||||
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
|
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
|
||||||
@ -219,31 +220,33 @@ class TestSetAssetPreview:
|
|||||||
asset = _make_asset(session, hash_val="blake3:main")
|
asset = _make_asset(session, hash_val="blake3:main")
|
||||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
ref_id = ref.id
|
ref_id = ref.id
|
||||||
preview_id = preview_asset.id
|
preview_ref_id = preview_ref.id
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_asset_preview(
|
set_asset_preview(
|
||||||
reference_id=ref_id,
|
reference_id=ref_id,
|
||||||
preview_asset_id=preview_id,
|
preview_reference_id=preview_ref_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify by re-fetching from DB
|
# Verify by re-fetching from DB
|
||||||
session.expire_all()
|
session.expire_all()
|
||||||
updated_ref = session.get(AssetReference, ref_id)
|
updated_ref = session.get(AssetReference, ref_id)
|
||||||
assert updated_ref.preview_id == preview_id
|
assert updated_ref.preview_id == preview_ref_id
|
||||||
|
|
||||||
def test_clears_preview(self, mock_create_session, session: Session):
|
def test_clears_preview(self, mock_create_session, session: Session):
|
||||||
asset = _make_asset(session)
|
asset = _make_asset(session)
|
||||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
ref.preview_id = preview_asset.id
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
|
ref.preview_id = preview_ref.id
|
||||||
ref_id = ref.id
|
ref_id = ref.id
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_asset_preview(
|
set_asset_preview(
|
||||||
reference_id=ref_id,
|
reference_id=ref_id,
|
||||||
preview_asset_id=None,
|
preview_reference_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify by re-fetching from DB
|
# Verify by re-fetching from DB
|
||||||
@ -263,6 +266,45 @@ class TestSetAssetPreview:
|
|||||||
with pytest.raises(PermissionError, match="not owner"):
|
with pytest.raises(PermissionError, match="not owner"):
|
||||||
set_asset_preview(
|
set_asset_preview(
|
||||||
reference_id=ref.id,
|
reference_id=ref.id,
|
||||||
preview_asset_id=None,
|
preview_reference_id=None,
|
||||||
owner_id="user2",
|
owner_id="user2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveHashToPath:
|
||||||
|
def test_returns_none_for_unknown_hash(self, mock_create_session):
|
||||||
|
result = resolve_hash_to_path("blake3:" + "a" * 64)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"ref_owner, query_owner, expect_found",
|
||||||
|
[
|
||||||
|
("user1", "user1", True),
|
||||||
|
("user1", "user2", False),
|
||||||
|
("", "anyone", True),
|
||||||
|
("", "", True),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"owner_sees_own_ref",
|
||||||
|
"other_owner_blocked",
|
||||||
|
"ownerless_visible_to_anyone",
|
||||||
|
"ownerless_visible_to_empty",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_owner_visibility(
|
||||||
|
self, ref_owner, query_owner, expect_found,
|
||||||
|
mock_create_session, session: Session, temp_dir,
|
||||||
|
):
|
||||||
|
f = temp_dir / "file.bin"
|
||||||
|
f.write_bytes(b"data")
|
||||||
|
asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
|
||||||
|
ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
|
||||||
|
ref.file_path = str(f)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
|
||||||
|
if expect_found:
|
||||||
|
assert result is not None
|
||||||
|
assert result.abs_path == str(f)
|
||||||
|
else:
|
||||||
|
assert result is None
|
||||||
|
|||||||
@ -113,11 +113,19 @@ class TestIngestFileFromPath:
|
|||||||
file_path = temp_dir / "with_preview.bin"
|
file_path = temp_dir / "with_preview.bin"
|
||||||
file_path.write_bytes(b"data")
|
file_path.write_bytes(b"data")
|
||||||
|
|
||||||
# Create a preview asset first
|
# Create a preview asset and reference
|
||||||
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
|
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
|
||||||
session.add(preview_asset)
|
session.add(preview_asset)
|
||||||
|
session.flush()
|
||||||
|
from app.assets.helpers import get_utc_now
|
||||||
|
now = get_utc_now()
|
||||||
|
preview_ref = AssetReference(
|
||||||
|
asset_id=preview_asset.id, name="preview.png", owner_id="",
|
||||||
|
created_at=now, updated_at=now, last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(preview_ref)
|
||||||
session.commit()
|
session.commit()
|
||||||
preview_id = preview_asset.id
|
preview_id = preview_ref.id
|
||||||
|
|
||||||
result = _ingest_file_from_path(
|
result = _ingest_file_from_path(
|
||||||
abs_path=str(file_path),
|
abs_path=str(file_path),
|
||||||
|
|||||||
123
tests-unit/assets_test/services/test_tag_histogram.py
Normal file
123
tests-unit/assets_test/services/test_tag_histogram.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
"""Tests for list_tag_histogram service function."""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.assets.database.models import Asset, AssetReference
|
||||||
|
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
|
||||||
|
from app.assets.helpers import get_utc_now
|
||||||
|
from app.assets.services.tagging import list_tag_histogram
|
||||||
|
|
||||||
|
|
||||||
|
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
|
||||||
|
asset = Asset(hash=hash_val, size_bytes=1024)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
|
def _make_reference(
|
||||||
|
session: Session,
|
||||||
|
asset: Asset,
|
||||||
|
name: str = "test",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> AssetReference:
|
||||||
|
now = get_utc_now()
|
||||||
|
ref = AssetReference(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(ref)
|
||||||
|
session.flush()
|
||||||
|
return ref
|
||||||
|
|
||||||
|
|
||||||
|
class TestListTagHistogram:
|
||||||
|
def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["alpha", "beta"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="r2")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram()
|
||||||
|
|
||||||
|
assert result["alpha"] == 2
|
||||||
|
assert result["beta"] == 1
|
||||||
|
|
||||||
|
def test_empty_when_no_assets(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["unused"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram()
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_include_tags_filter(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="r2")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(include_tags=["models"])
|
||||||
|
|
||||||
|
# Only r1 has "models", so only its tags appear
|
||||||
|
assert "models" in result
|
||||||
|
assert "loras" in result
|
||||||
|
assert "input" not in result
|
||||||
|
|
||||||
|
def test_exclude_tags_filter(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="r2")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(exclude_tags=["models"])
|
||||||
|
|
||||||
|
# r1 excluded, only r2's tags remain
|
||||||
|
assert "input" in result
|
||||||
|
assert "loras" not in result
|
||||||
|
|
||||||
|
def test_name_contains_filter(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["alpha", "beta"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="my_model.safetensors")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="picture.png")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(name_contains="model")
|
||||||
|
|
||||||
|
assert "alpha" in result
|
||||||
|
assert "beta" not in result
|
||||||
|
|
||||||
|
def test_limit_caps_results(self, mock_create_session, session: Session):
|
||||||
|
tags = [f"tag{i}" for i in range(10)]
|
||||||
|
ensure_tags_exist(session, tags)
|
||||||
|
a = _make_asset(session, "blake3:aaa")
|
||||||
|
r = _make_reference(session, a, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r.id, tags=tags)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(limit=3)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
|||||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("root", ["input", "output"])
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
def test_duplicate_upload_same_display_name_does_not_clobber(
|
def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||||
root: str,
|
root: str,
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user