mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-09 00:32:31 +08:00
Merge branch 'master' into mps-text-encoder-device
This commit is contained in:
commit
46d76508c9
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
@ -0,0 +1,103 @@
|
||||
#!/usr/bin/env bash
|
||||
# Checks pull request commits for AI agent Co-authored-by trailers.
|
||||
# Exits non-zero when any are found and prints fix instructions.
|
||||
set -euo pipefail
|
||||
|
||||
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||
|
||||
# Known AI coding-agent trailer patterns (case-insensitive).
|
||||
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
|
||||
AGENT_PATTERNS=(
|
||||
# Anthropic — Claude Code / Amp
|
||||
'noreply@anthropic\.com'
|
||||
# Cursor
|
||||
'cursoragent@cursor\.com'
|
||||
# GitHub Copilot
|
||||
'copilot-swe-agent\[bot\]'
|
||||
'copilot@github\.com'
|
||||
# OpenAI Codex
|
||||
'noreply@openai\.com'
|
||||
'codex@openai\.com'
|
||||
# Aider
|
||||
'aider@aider\.chat'
|
||||
# Google — Gemini / Jules
|
||||
'gemini@google\.com'
|
||||
'jules@google\.com'
|
||||
# Windsurf / Codeium
|
||||
'@codeium\.com'
|
||||
# Devin
|
||||
'devin-ai-integration\[bot\]'
|
||||
'devin@cognition\.ai'
|
||||
'devin@cognition-labs\.com'
|
||||
# Amazon Q Developer
|
||||
'amazon-q-developer'
|
||||
'@amazon\.com.*[Qq].[Dd]eveloper'
|
||||
# Cline
|
||||
'cline-bot'
|
||||
'cline@cline\.ai'
|
||||
# Continue
|
||||
'continue-agent'
|
||||
'continue@continue\.dev'
|
||||
# Sourcegraph
|
||||
'noreply@sourcegraph\.com'
|
||||
# Generic catch-alls for common agent name patterns
|
||||
'Co-authored-by:.*\b[Cc]laude\b'
|
||||
'Co-authored-by:.*\b[Cc]opilot\b'
|
||||
'Co-authored-by:.*\b[Cc]ursor\b'
|
||||
'Co-authored-by:.*\b[Cc]odex\b'
|
||||
'Co-authored-by:.*\b[Gg]emini\b'
|
||||
'Co-authored-by:.*\b[Aa]ider\b'
|
||||
'Co-authored-by:.*\b[Dd]evin\b'
|
||||
'Co-authored-by:.*\b[Ww]indsurf\b'
|
||||
'Co-authored-by:.*\b[Cc]line\b'
|
||||
'Co-authored-by:.*\b[Aa]mazon Q\b'
|
||||
'Co-authored-by:.*\b[Jj]ules\b'
|
||||
'Co-authored-by:.*\bOpenCode\b'
|
||||
)
|
||||
|
||||
# Build a single alternation regex from all patterns.
|
||||
regex=""
|
||||
for pattern in "${AGENT_PATTERNS[@]}"; do
|
||||
if [[ -n "$regex" ]]; then
|
||||
regex="${regex}|${pattern}"
|
||||
else
|
||||
regex="$pattern"
|
||||
fi
|
||||
done
|
||||
|
||||
# Collect Co-authored-by lines from every commit in the PR range.
|
||||
violations=""
|
||||
while IFS= read -r sha; do
|
||||
message="$(git log -1 --format='%B' "$sha")"
|
||||
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
|
||||
if [[ -z "$matched_lines" ]]; then
|
||||
continue
|
||||
fi
|
||||
|
||||
while IFS= read -r line; do
|
||||
if echo "$line" | grep -iqE "$regex"; then
|
||||
short="$(git log -1 --format='%h' "$sha")"
|
||||
violations="${violations} ${short}: ${line}"$'\n'
|
||||
fi
|
||||
done <<< "$matched_lines"
|
||||
done < <(git rev-list "${base_sha}..${head_sha}")
|
||||
|
||||
if [[ -n "$violations" ]]; then
|
||||
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
|
||||
echo ""
|
||||
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
|
||||
echo ""
|
||||
echo "$violations"
|
||||
echo "These trailers should be removed before merging."
|
||||
echo ""
|
||||
echo "To fix, rewrite the commit messages with:"
|
||||
echo " git rebase -i ${base_sha}"
|
||||
echo ""
|
||||
echo "and remove the Co-authored-by lines, then force-push your branch."
|
||||
echo ""
|
||||
echo "If you believe this is a false positive, please open an issue."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
echo "No AI agent Co-authored-by trailers found."
|
||||
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
||||
name: Check AI Co-Authors
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
branches: ['*']
|
||||
|
||||
jobs:
|
||||
check-ai-co-authors:
|
||||
name: Check for AI agent co-author trailers
|
||||
runs-on: ubuntu-latest
|
||||
|
||||
steps:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@v4
|
||||
with:
|
||||
fetch-depth: 0
|
||||
|
||||
- name: Check commits for AI co-author trailers
|
||||
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
|
||||
11
README.md
11
README.md
@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
|
||||
## Get Started
|
||||
|
||||
### Local
|
||||
|
||||
#### [Desktop Application](https://www.comfy.org/download)
|
||||
- The easiest way to get started.
|
||||
- Available on Windows & macOS.
|
||||
@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
|
||||
#### [Manual Install](#manual-install-windows-linux)
|
||||
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
|
||||
|
||||
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/)
|
||||
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
### Cloud
|
||||
|
||||
#### [Comfy Cloud](https://www.comfy.org/cloud)
|
||||
- Our official paid cloud version for those who can't afford local hardware.
|
||||
|
||||
## Examples
|
||||
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
|
||||
|
||||
## Features
|
||||
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.
|
||||
|
||||
@ -8,7 +8,7 @@ from alembic import context
|
||||
config = context.config
|
||||
|
||||
|
||||
from app.database.models import Base
|
||||
from app.database.models import Base, NAMING_CONVENTION
|
||||
target_metadata = Base.metadata
|
||||
|
||||
# other values from the config, defined by the needs of env.py,
|
||||
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
|
||||
|
||||
with connectable.connect() as connection:
|
||||
context.configure(
|
||||
connection=connection, target_metadata=target_metadata
|
||||
connection=connection,
|
||||
target_metadata=target_metadata,
|
||||
render_as_batch=True,
|
||||
naming_convention=NAMING_CONVENTION,
|
||||
)
|
||||
|
||||
with context.begin_transaction():
|
||||
|
||||
267
alembic_db/versions/0002_merge_to_asset_references.py
Normal file
267
alembic_db/versions/0002_merge_to_asset_references.py
Normal file
@ -0,0 +1,267 @@
|
||||
"""
|
||||
Merge AssetInfo and AssetCacheState into unified asset_references table.
|
||||
|
||||
This migration drops old tables and creates the new unified schema.
|
||||
All existing data is discarded.
|
||||
|
||||
Revision ID: 0002_merge_to_asset_references
|
||||
Revises: 0001_assets
|
||||
Create Date: 2025-02-11
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0002_merge_to_asset_references"
|
||||
down_revision = "0001_assets"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Drop old tables (order matters due to FK constraints)
|
||||
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
|
||||
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
|
||||
op.drop_table("asset_info_meta")
|
||||
|
||||
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
|
||||
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
|
||||
op.drop_table("asset_info_tags")
|
||||
|
||||
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
|
||||
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
|
||||
op.drop_table("asset_cache_state")
|
||||
|
||||
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_name", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
|
||||
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
|
||||
op.drop_table("assets_info")
|
||||
|
||||
# Truncate assets table (cascades handled by dropping dependent tables first)
|
||||
op.execute("DELETE FROM assets")
|
||||
|
||||
# Create asset_references table
|
||||
op.create_table(
|
||||
"asset_references",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column(
|
||||
"asset_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("assets.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("file_path", sa.Text(), nullable=True),
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column(
|
||||
"needs_verify",
|
||||
sa.Boolean(),
|
||||
nullable=False,
|
||||
server_default=sa.text("false"),
|
||||
),
|
||||
sa.Column(
|
||||
"is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
|
||||
),
|
||||
sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column(
|
||||
"preview_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("assets.id", ondelete="SET NULL"),
|
||||
nullable=True,
|
||||
),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True),
|
||||
sa.CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
),
|
||||
sa.CheckConstraint(
|
||||
"enrichment_level >= 0 AND enrichment_level <= 2",
|
||||
name="ck_ar_enrichment_level_range",
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
|
||||
)
|
||||
op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
|
||||
op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
|
||||
op.create_index("ix_asset_references_name", "asset_references", ["name"])
|
||||
op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
|
||||
op.create_index(
|
||||
"ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
|
||||
)
|
||||
op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
|
||||
op.create_index(
|
||||
"ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
|
||||
)
|
||||
op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"])
|
||||
|
||||
# Create asset_reference_tags table
|
||||
op.create_table(
|
||||
"asset_reference_tags",
|
||||
sa.Column(
|
||||
"asset_reference_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"tag_name",
|
||||
sa.String(length=512),
|
||||
sa.ForeignKey("tags.name", ondelete="RESTRICT"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column(
|
||||
"origin", sa.String(length=32), nullable=False, server_default="manual"
|
||||
),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint(
|
||||
"asset_reference_id", "tag_name", name="pk_asset_reference_tags"
|
||||
),
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_tags_asset_reference_id",
|
||||
"asset_reference_tags",
|
||||
["asset_reference_id"],
|
||||
)
|
||||
|
||||
# Create asset_reference_meta table
|
||||
op.create_table(
|
||||
"asset_reference_meta",
|
||||
sa.Column(
|
||||
"asset_reference_id",
|
||||
sa.String(length=36),
|
||||
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
nullable=False,
|
||||
),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint(
|
||||
"asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
|
||||
),
|
||||
)
|
||||
op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
|
||||
op.create_index(
|
||||
"ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
|
||||
)
|
||||
op.create_index(
|
||||
"ix_asset_reference_meta_key_val_bool",
|
||||
"asset_reference_meta",
|
||||
["key", "val_bool"],
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
"""Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.
|
||||
|
||||
NOTE: Data is not recoverable. The upgrade discards all rows from the old
|
||||
tables and truncates assets. After downgrade the old schema will be empty.
|
||||
A filesystem rescan will repopulate data once the older code is running.
|
||||
"""
|
||||
# Drop new tables (order matters due to FK constraints)
|
||||
op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
|
||||
op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
|
||||
op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
|
||||
op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
|
||||
op.drop_table("asset_reference_meta")
|
||||
|
||||
op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
|
||||
op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
|
||||
op.drop_table("asset_reference_tags")
|
||||
|
||||
op.drop_index("ix_asset_references_deleted_at", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_created_at", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_name", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
|
||||
op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
|
||||
op.drop_index("uq_asset_references_file_path", table_name="asset_references")
|
||||
op.drop_table("asset_references")
|
||||
|
||||
# Truncate assets (upgrade deleted all rows; downgrade starts fresh too)
|
||||
op.execute("DELETE FROM assets")
|
||||
|
||||
# Recreate old tables from 0001_assets schema
|
||||
op.create_table(
|
||||
"assets_info",
|
||||
sa.Column("id", sa.String(length=36), primary_key=True),
|
||||
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
|
||||
sa.Column("name", sa.String(length=512), nullable=False),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
|
||||
sa.Column("user_metadata", sa.JSON(), nullable=True),
|
||||
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
)
|
||||
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
|
||||
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
|
||||
op.create_index("ix_assets_info_name", "assets_info", ["name"])
|
||||
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
|
||||
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
|
||||
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
|
||||
|
||||
op.create_table(
|
||||
"asset_cache_state",
|
||||
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
|
||||
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("file_path", sa.Text(), nullable=False),
|
||||
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
|
||||
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
|
||||
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
)
|
||||
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
|
||||
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
|
||||
|
||||
op.create_table(
|
||||
"asset_info_tags",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
|
||||
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
|
||||
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
|
||||
)
|
||||
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
|
||||
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
|
||||
|
||||
op.create_table(
|
||||
"asset_info_meta",
|
||||
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
|
||||
sa.Column("key", sa.String(length=256), nullable=False),
|
||||
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
|
||||
sa.Column("val_str", sa.String(length=2048), nullable=True),
|
||||
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
|
||||
sa.Column("val_bool", sa.Boolean(), nullable=True),
|
||||
sa.Column("val_json", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
|
||||
)
|
||||
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
|
||||
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
|
||||
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
|
||||
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])
|
||||
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
@ -0,0 +1,98 @@
|
||||
"""
|
||||
Add system_metadata and job_id columns to asset_references.
|
||||
Change preview_id FK from assets.id to asset_references.id.
|
||||
|
||||
Revision ID: 0003_add_metadata_job_id
|
||||
Revises: 0002_merge_to_asset_references
|
||||
Create Date: 2026-03-09
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
from app.database.models import NAMING_CONVENTION
|
||||
|
||||
revision = "0003_add_metadata_job_id"
|
||||
down_revision = "0002_merge_to_asset_references"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column("system_metadata", sa.JSON(), nullable=True)
|
||||
)
|
||||
batch_op.add_column(
|
||||
sa.Column("job_id", sa.String(length=36), nullable=True)
|
||||
)
|
||||
|
||||
# Change preview_id FK from assets.id to asset_references.id (self-ref).
|
||||
# Existing values are asset-content IDs that won't match reference IDs,
|
||||
# so null them out first.
|
||||
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_assets", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_asset_references",
|
||||
"asset_references",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
batch_op.create_index(
|
||||
"ix_asset_references_preview_id", ["preview_id"]
|
||||
)
|
||||
|
||||
# Purge any all-null meta rows before adding the constraint
|
||||
op.execute(
|
||||
"DELETE FROM asset_reference_meta"
|
||||
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
|
||||
)
|
||||
with op.batch_alter_table("asset_reference_meta") as batch_op:
|
||||
batch_op.create_check_constraint(
|
||||
"ck_asset_reference_meta_has_value",
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# SQLite doesn't reflect CHECK constraints, so we must declare it
|
||||
# explicitly via table_args for the batch recreate to find it.
|
||||
# Use the fully-rendered constraint name to avoid the naming convention
|
||||
# doubling the prefix.
|
||||
with op.batch_alter_table(
|
||||
"asset_reference_meta",
|
||||
table_args=[
|
||||
sa.CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="ck_asset_reference_meta_has_value",
|
||||
),
|
||||
],
|
||||
) as batch_op:
|
||||
batch_op.drop_constraint(
|
||||
"ck_asset_reference_meta_has_value", type_="check"
|
||||
)
|
||||
|
||||
with op.batch_alter_table(
|
||||
"asset_references", naming_convention=NAMING_CONVENTION
|
||||
) as batch_op:
|
||||
batch_op.drop_index("ix_asset_references_preview_id")
|
||||
batch_op.drop_constraint(
|
||||
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
|
||||
)
|
||||
batch_op.create_foreign_key(
|
||||
"fk_asset_references_preview_id_assets",
|
||||
"assets",
|
||||
["preview_id"],
|
||||
["id"],
|
||||
ondelete="SET NULL",
|
||||
)
|
||||
|
||||
with op.batch_alter_table("asset_references") as batch_op:
|
||||
batch_op.drop_column("job_id")
|
||||
batch_op.drop_column("system_metadata")
|
||||
File diff suppressed because it is too large
Load Diff
@ -1,6 +1,8 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Literal
|
||||
|
||||
from app.assets.helpers import validate_blake3_hash
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
@ -10,6 +12,43 @@ from pydantic import (
|
||||
model_validator,
|
||||
)
|
||||
|
||||
|
||||
class UploadError(Exception):
|
||||
"""Error during upload parsing with HTTP status and code."""
|
||||
|
||||
def __init__(self, status: int, code: str, message: str):
|
||||
super().__init__(message)
|
||||
self.status = status
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
class AssetValidationError(Exception):
|
||||
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
|
||||
|
||||
def __init__(self, code: str, message: str):
|
||||
super().__init__(message)
|
||||
self.code = code
|
||||
self.message = message
|
||||
|
||||
|
||||
@dataclass
|
||||
class ParsedUpload:
|
||||
"""Result of parsing a multipart upload request."""
|
||||
|
||||
file_present: bool
|
||||
file_written: int
|
||||
file_client_name: str | None
|
||||
tmp_path: str | None
|
||||
tags_raw: list[str]
|
||||
provided_name: str | None
|
||||
user_metadata_raw: str | None
|
||||
provided_hash: str | None
|
||||
provided_hash_exists: bool | None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
|
||||
class ListAssetsQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
@ -21,7 +60,9 @@ class ListAssetsQuery(BaseModel):
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
)
|
||||
order: Literal["asc", "desc"] = "desc"
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@ -59,11 +100,17 @@ class ListAssetsQuery(BaseModel):
|
||||
class UpdateAssetBody(BaseModel):
|
||||
name: str | None = None
|
||||
user_metadata: dict[str, Any] | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _at_least_one(self):
|
||||
if self.name is None and self.user_metadata is None:
|
||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
||||
def _validate_at_least_one_field(self):
|
||||
if all(
|
||||
v is None
|
||||
for v in (self.name, self.user_metadata, self.preview_id)
|
||||
):
|
||||
raise ValueError(
|
||||
"Provide at least one of: name, user_metadata, preview_id."
|
||||
)
|
||||
return self
|
||||
|
||||
|
||||
@ -71,26 +118,20 @@ class CreateFromHashBody(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
hash: str
|
||||
name: str
|
||||
name: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
mime_type: str | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
|
||||
@field_validator("hash")
|
||||
@classmethod
|
||||
def _require_blake3(cls, v):
|
||||
s = (v or "").strip().lower()
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return s
|
||||
return validate_blake3_hash(v or "")
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
def _tags_norm(cls, v):
|
||||
def _normalize_tags_field(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, list):
|
||||
@ -107,6 +148,44 @@ class CreateFromHashBody(BaseModel):
|
||||
return []
|
||||
|
||||
|
||||
class TagsRefineQuery(BaseModel):
|
||||
include_tags: list[str] = Field(default_factory=list)
|
||||
exclude_tags: list[str] = Field(default_factory=list)
|
||||
name_contains: str | None = None
|
||||
metadata_filter: dict[str, Any] | None = None
|
||||
limit: conint(ge=1, le=1000) = 100
|
||||
|
||||
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||
@classmethod
|
||||
def _split_csv_tags(cls, v):
|
||||
if v is None:
|
||||
return []
|
||||
if isinstance(v, str):
|
||||
return [t.strip() for t in v.split(",") if t.strip()]
|
||||
if isinstance(v, list):
|
||||
out: list[str] = []
|
||||
for item in v:
|
||||
if isinstance(item, str):
|
||||
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||
return out
|
||||
return v
|
||||
|
||||
@field_validator("metadata_filter", mode="before")
|
||||
@classmethod
|
||||
def _parse_metadata_json(cls, v):
|
||||
if v is None or isinstance(v, dict):
|
||||
return v
|
||||
if isinstance(v, str) and v.strip():
|
||||
try:
|
||||
parsed = json.loads(v)
|
||||
except Exception as e:
|
||||
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||
if not isinstance(parsed, dict):
|
||||
raise ValueError("metadata_filter must be a JSON object")
|
||||
return parsed
|
||||
return None
|
||||
|
||||
|
||||
class TagsListQuery(BaseModel):
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
@ -154,38 +233,36 @@ class TagsRemove(TagsAdd):
|
||||
|
||||
class UploadAssetSpec(BaseModel):
|
||||
"""Upload Asset operation.
|
||||
- tags: ordered; first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
|
||||
|
||||
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||
if root == 'models', second must be a valid category
|
||||
- name: display name
|
||||
- user_metadata: arbitrary JSON object (optional)
|
||||
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
|
||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||
- mime_type: optional MIME type override
|
||||
- preview_id: optional asset_reference ID for preview
|
||||
|
||||
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
|
||||
and the original extension is preserved when available.
|
||||
Files are stored using the content hash as filename stem.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||
|
||||
tags: list[str] = Field(..., min_length=1)
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
hash: str | None = Field(default=None)
|
||||
mime_type: str | None = Field(default=None)
|
||||
preview_id: str | None = Field(default=None) # references an asset_reference id
|
||||
|
||||
@field_validator("hash", mode="before")
|
||||
@classmethod
|
||||
def _parse_hash(cls, v):
|
||||
if v is None:
|
||||
return None
|
||||
s = str(v).strip().lower()
|
||||
s = str(v).strip()
|
||||
if not s:
|
||||
return None
|
||||
if ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if algo != "blake3":
|
||||
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
|
||||
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
|
||||
raise ValueError("hash digest must be lowercase hex")
|
||||
return f"{algo}:{digest}"
|
||||
return validate_blake3_hash(s)
|
||||
|
||||
@field_validator("tags", mode="before")
|
||||
@classmethod
|
||||
@ -254,11 +331,13 @@ class UploadAssetSpec(BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _validate_order(self):
|
||||
if not self.tags:
|
||||
raise ValueError("tags must be provided and non-empty")
|
||||
raise ValueError("at least one tag is required for uploads")
|
||||
root = self.tags[0]
|
||||
if root not in {"models", "input", "output"}:
|
||||
raise ValueError("first tag must be one of: models, input, output")
|
||||
if root == "models":
|
||||
if len(self.tags) < 2:
|
||||
raise ValueError("models uploads require a category tag as the second tag")
|
||||
raise ValueError(
|
||||
"models uploads require a category tag as the second tag"
|
||||
)
|
||||
return self
|
||||
|
||||
@ -4,7 +4,10 @@ from typing import Any
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||
|
||||
|
||||
class AssetSummary(BaseModel):
|
||||
class Asset(BaseModel):
|
||||
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
|
||||
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
@ -12,61 +15,33 @@ class AssetSummary(BaseModel):
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
preview_url: str | None = None
|
||||
created_at: datetime | None = None
|
||||
updated_at: datetime | None = None
|
||||
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
is_immutable: bool = False
|
||||
metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
prompt_id: str | None = None # deprecated: use job_id
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "updated_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
def _serialize_datetime(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(Asset):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class AssetsList(BaseModel):
|
||||
assets: list[AssetSummary]
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
|
||||
|
||||
class AssetUpdated(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
updated_at: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("updated_at")
|
||||
def _ser_updated(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetDetail(BaseModel):
|
||||
id: str
|
||||
name: str
|
||||
asset_hash: str | None = None
|
||||
size: int | None = None
|
||||
mime_type: str | None = None
|
||||
tags: list[str] = Field(default_factory=list)
|
||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
preview_id: str | None = None
|
||||
created_at: datetime | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
model_config = ConfigDict(from_attributes=True)
|
||||
|
||||
@field_serializer("created_at", "last_access_time")
|
||||
def _ser_dt(self, v: datetime | None, _info):
|
||||
return v.isoformat() if v else None
|
||||
|
||||
|
||||
class AssetCreated(AssetDetail):
|
||||
created_new: bool
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
|
||||
removed: list[str] = Field(default_factory=list)
|
||||
not_present: list[str] = Field(default_factory=list)
|
||||
total_tags: list[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class TagHistogram(BaseModel):
|
||||
tag_counts: dict[str, int]
|
||||
|
||||
185
app/assets/api/upload.py
Normal file
185
app/assets/api/upload.py
Normal file
@ -0,0 +1,185 @@
|
||||
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Callable
|
||||
|
||||
from aiohttp import web
|
||||
|
||||
import folder_paths
|
||||
from app.assets.api.schemas_in import ParsedUpload, UploadError
|
||||
from app.assets.helpers import validate_blake3_hash
|
||||
|
||||
|
||||
def normalize_and_validate_hash(s: str) -> str:
|
||||
"""Validate and normalize a hash string.
|
||||
|
||||
Returns canonical 'blake3:<hex>' or raises UploadError.
|
||||
"""
|
||||
try:
|
||||
return validate_blake3_hash(s)
|
||||
except ValueError:
|
||||
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
|
||||
|
||||
|
||||
async def parse_multipart_upload(
|
||||
request: web.Request,
|
||||
check_hash_exists: Callable[[str], bool],
|
||||
) -> ParsedUpload:
|
||||
"""
|
||||
Parse a multipart/form-data upload request.
|
||||
|
||||
Args:
|
||||
request: The aiohttp request
|
||||
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
|
||||
|
||||
Returns:
|
||||
ParsedUpload with parsed fields and temp file path
|
||||
|
||||
Raises:
|
||||
UploadError: On validation or I/O errors
|
||||
"""
|
||||
if not (request.content_type or "").lower().startswith("multipart/"):
|
||||
raise UploadError(
|
||||
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
|
||||
)
|
||||
|
||||
reader = await request.multipart()
|
||||
|
||||
file_present = False
|
||||
file_client_name: str | None = None
|
||||
tags_raw: list[str] = []
|
||||
provided_name: str | None = None
|
||||
user_metadata_raw: str | None = None
|
||||
provided_hash: str | None = None
|
||||
provided_hash_exists: bool | None = None
|
||||
provided_mime_type: str | None = None
|
||||
provided_preview_id: str | None = None
|
||||
|
||||
file_written = 0
|
||||
tmp_path: str | None = None
|
||||
|
||||
while True:
|
||||
field = await reader.next()
|
||||
if field is None:
|
||||
break
|
||||
|
||||
fname = getattr(field, "name", "") or ""
|
||||
|
||||
if fname == "hash":
|
||||
try:
|
||||
s = ((await field.text()) or "").strip().lower()
|
||||
except Exception:
|
||||
raise UploadError(
|
||||
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
|
||||
)
|
||||
|
||||
if s:
|
||||
provided_hash = normalize_and_validate_hash(s)
|
||||
try:
|
||||
provided_hash_exists = check_hash_exists(provided_hash)
|
||||
except Exception as e:
|
||||
logging.exception(
|
||||
"check_hash_exists failed for hash=%s: %s", provided_hash, e
|
||||
)
|
||||
raise UploadError(
|
||||
500,
|
||||
"HASH_CHECK_FAILED",
|
||||
"Backend error while checking asset hash.",
|
||||
)
|
||||
|
||||
elif fname == "file":
|
||||
file_present = True
|
||||
file_client_name = (field.filename or "").strip()
|
||||
|
||||
if provided_hash and provided_hash_exists is True:
|
||||
# Hash exists - drain file but don't write to disk
|
||||
try:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
raise UploadError(
|
||||
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
|
||||
)
|
||||
continue
|
||||
|
||||
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
|
||||
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
|
||||
os.makedirs(unique_dir, exist_ok=True)
|
||||
tmp_path = os.path.join(unique_dir, ".upload.part")
|
||||
|
||||
try:
|
||||
with open(tmp_path, "wb") as f:
|
||||
while True:
|
||||
chunk = await field.read_chunk(8 * 1024 * 1024)
|
||||
if not chunk:
|
||||
break
|
||||
f.write(chunk)
|
||||
file_written += len(chunk)
|
||||
except Exception:
|
||||
delete_temp_file_if_exists(tmp_path)
|
||||
raise UploadError(
|
||||
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
|
||||
)
|
||||
|
||||
elif fname == "tags":
|
||||
tags_raw.append((await field.text()) or "")
|
||||
elif fname == "name":
|
||||
provided_name = (await field.text()) or None
|
||||
elif fname == "user_metadata":
|
||||
user_metadata_raw = (await field.text()) or None
|
||||
elif fname == "id":
|
||||
raise UploadError(
|
||||
400,
|
||||
"UNSUPPORTED_FIELD",
|
||||
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
|
||||
)
|
||||
elif fname == "mime_type":
|
||||
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||
elif fname == "preview_id":
|
||||
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||
|
||||
if not file_present and not (provided_hash and provided_hash_exists):
|
||||
raise UploadError(
|
||||
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
|
||||
)
|
||||
|
||||
if (
|
||||
file_present
|
||||
and file_written == 0
|
||||
and not (provided_hash and provided_hash_exists)
|
||||
):
|
||||
delete_temp_file_if_exists(tmp_path)
|
||||
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
|
||||
|
||||
return ParsedUpload(
|
||||
file_present=file_present,
|
||||
file_written=file_written,
|
||||
file_client_name=file_client_name,
|
||||
tmp_path=tmp_path,
|
||||
tags_raw=tags_raw,
|
||||
provided_name=provided_name,
|
||||
user_metadata_raw=user_metadata_raw,
|
||||
provided_hash=provided_hash,
|
||||
provided_hash_exists=provided_hash_exists,
|
||||
provided_mime_type=provided_mime_type,
|
||||
provided_preview_id=provided_preview_id,
|
||||
)
|
||||
|
||||
|
||||
def delete_temp_file_if_exists(tmp_path: str | None) -> None:
|
||||
"""Safely remove a temp file and its parent directory if empty."""
|
||||
if tmp_path:
|
||||
try:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
except OSError as e:
|
||||
logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
|
||||
try:
|
||||
parent = os.path.dirname(tmp_path)
|
||||
if parent and os.path.isdir(parent):
|
||||
os.rmdir(parent) # only succeeds if empty
|
||||
except OSError:
|
||||
pass
|
||||
@ -1,204 +0,0 @@
|
||||
import os
|
||||
import uuid
|
||||
import sqlalchemy
|
||||
from typing import Iterable
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
|
||||
if not rows:
|
||||
return []
|
||||
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
|
||||
for i in range(0, len(rows), rows_per_stmt):
|
||||
yield rows[i:i + rows_per_stmt]
|
||||
|
||||
def _iter_chunks(seq, n: int):
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i:i + n]
|
||||
|
||||
def _rows_per_stmt(cols: int) -> int:
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def seed_from_paths_batch(
|
||||
session: Session,
|
||||
*,
|
||||
specs: list[dict],
|
||||
owner_id: str = "",
|
||||
) -> dict:
|
||||
"""Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
"""
|
||||
if not specs:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
|
||||
|
||||
now = utcnow()
|
||||
asset_rows: list[dict] = []
|
||||
state_rows: list[dict] = []
|
||||
path_to_asset: dict[str, str] = {}
|
||||
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
|
||||
path_list: list[str] = []
|
||||
|
||||
for sp in specs:
|
||||
ap = os.path.abspath(sp["abs_path"])
|
||||
aid = str(uuid.uuid4())
|
||||
iid = str(uuid.uuid4())
|
||||
path_list.append(ap)
|
||||
path_to_asset[ap] = aid
|
||||
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": aid,
|
||||
"hash": None,
|
||||
"size_bytes": sp["size_bytes"],
|
||||
"mime_type": None,
|
||||
"created_at": now,
|
||||
}
|
||||
)
|
||||
state_rows.append(
|
||||
{
|
||||
"asset_id": aid,
|
||||
"file_path": ap,
|
||||
"mtime_ns": sp["mtime_ns"],
|
||||
}
|
||||
)
|
||||
asset_to_info[aid] = {
|
||||
"id": iid,
|
||||
"owner_id": owner_id,
|
||||
"name": sp["info_name"],
|
||||
"asset_id": aid,
|
||||
"preview_id": None,
|
||||
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
"_tags": sp["tags"],
|
||||
"_filename": sp["fname"],
|
||||
}
|
||||
|
||||
# insert all seed Assets (hash=NULL)
|
||||
ins_asset = sqlite.insert(Asset)
|
||||
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
|
||||
session.execute(ins_asset, chunk)
|
||||
|
||||
# try to claim AssetCacheState (file_path)
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
|
||||
ins_state = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
|
||||
session.execute(ins_state, chunk)
|
||||
|
||||
# Query to find which of our paths won (were actually inserted)
|
||||
winners_by_path: set[str] = set()
|
||||
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetCacheState.file_path)
|
||||
.where(AssetCacheState.file_path.in_(chunk))
|
||||
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
|
||||
)
|
||||
winners_by_path.update(result.scalars().all())
|
||||
|
||||
all_paths_set = set(path_list)
|
||||
losers_by_path = all_paths_set - winners_by_path
|
||||
lost_assets = [path_to_asset[p] for p in losers_by_path]
|
||||
if lost_assets: # losers get their Asset removed
|
||||
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
|
||||
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
|
||||
|
||||
if not winners_by_path:
|
||||
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
|
||||
|
||||
# insert AssetInfo only for winners
|
||||
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
|
||||
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
|
||||
ins_info = (
|
||||
sqlite.insert(AssetInfo)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
|
||||
)
|
||||
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
|
||||
session.execute(ins_info, chunk)
|
||||
|
||||
# Query to find which info rows were actually inserted (by matching our generated IDs)
|
||||
all_info_ids = [row["id"] for row in winner_info_rows]
|
||||
inserted_info_ids: set[str] = set()
|
||||
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
|
||||
result = session.execute(
|
||||
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
|
||||
)
|
||||
inserted_info_ids.update(result.scalars().all())
|
||||
|
||||
# build and insert tag + meta rows for the AssetInfo
|
||||
tag_rows: list[dict] = []
|
||||
meta_rows: list[dict] = []
|
||||
if inserted_info_ids:
|
||||
for row in winner_info_rows:
|
||||
iid = row["id"]
|
||||
if iid not in inserted_info_ids:
|
||||
continue
|
||||
for t in row["_tags"]:
|
||||
tag_rows.append({
|
||||
"asset_info_id": iid,
|
||||
"tag_name": t,
|
||||
"origin": "automatic",
|
||||
"added_at": now,
|
||||
})
|
||||
if row["_filename"]:
|
||||
meta_rows.append(
|
||||
{
|
||||
"asset_info_id": iid,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": row["_filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
|
||||
return {
|
||||
"inserted_infos": len(inserted_info_ids),
|
||||
"won_states": len(winners_by_path),
|
||||
"lost_states": len(losers_by_path),
|
||||
}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
*,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
max_bind_params: int,
|
||||
) -> None:
|
||||
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
|
||||
- tag_rows keys: asset_info_id, tag_name, origin, added_at
|
||||
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_links = (
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
|
||||
session.execute(ins_links, chunk)
|
||||
if meta_rows:
|
||||
ins_meta = (
|
||||
sqlite.insert(AssetInfoMeta)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
|
||||
)
|
||||
)
|
||||
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
|
||||
session.execute(ins_meta, chunk)
|
||||
@ -2,8 +2,8 @@ from __future__ import annotations
|
||||
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import (
|
||||
JSON,
|
||||
BigInteger,
|
||||
@ -16,47 +16,36 @@ from sqlalchemy import (
|
||||
Numeric,
|
||||
String,
|
||||
Text,
|
||||
UniqueConstraint,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
|
||||
|
||||
from app.assets.helpers import utcnow
|
||||
from app.database.models import to_dict, Base
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.database.models import Base
|
||||
|
||||
|
||||
class Asset(Base):
|
||||
__tablename__ = "assets"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
|
||||
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
|
||||
mime_type: Mapped[str | None] = mapped_column(String(255))
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
|
||||
infos: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
references: Mapped[list[AssetReference]] = relationship(
|
||||
"AssetReference",
|
||||
back_populates="asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
|
||||
foreign_keys=lambda: [AssetInfo.asset_id],
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
|
||||
foreign_keys=lambda: [AssetReference.asset_id],
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
preview_of: Mapped[list[AssetInfo]] = relationship(
|
||||
"AssetInfo",
|
||||
back_populates="preview_asset",
|
||||
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
|
||||
foreign_keys=lambda: [AssetInfo.preview_id],
|
||||
viewonly=True,
|
||||
)
|
||||
|
||||
cache_states: Mapped[list[AssetCacheState]] = relationship(
|
||||
back_populates="asset",
|
||||
cascade="all, delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
# preview_id on AssetReference is a self-referential FK to asset_references.id
|
||||
|
||||
__table_args__ = (
|
||||
Index("uq_assets_hash", "hash", unique=True),
|
||||
@ -64,108 +53,126 @@ class Asset(Base):
|
||||
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
|
||||
|
||||
|
||||
class AssetCacheState(Base):
|
||||
__tablename__ = "asset_cache_state"
|
||||
class AssetReference(Base):
|
||||
"""Unified model combining file cache state and user-facing metadata.
|
||||
|
||||
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
|
||||
file_path: Mapped[str] = mapped_column(Text, nullable=False)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
Each row represents either:
|
||||
- A filesystem reference (file_path is set) with cache state
|
||||
- An API-created reference (file_path is NULL) without cache state
|
||||
"""
|
||||
|
||||
asset: Mapped[Asset] = relationship(back_populates="cache_states")
|
||||
__tablename__ = "asset_references"
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_cache_state_file_path", "file_path"),
|
||||
Index("ix_asset_cache_state_asset_id", "asset_id"),
|
||||
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
|
||||
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
|
||||
id: Mapped[str] = mapped_column(
|
||||
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
|
||||
)
|
||||
asset_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
return to_dict(self, include_none=include_none)
|
||||
# Cache state fields (from former AssetCacheState)
|
||||
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
|
||||
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
|
||||
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
|
||||
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
|
||||
|
||||
|
||||
class AssetInfo(Base):
|
||||
__tablename__ = "assets_info"
|
||||
|
||||
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
|
||||
# Info fields (from former AssetInfo)
|
||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
|
||||
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
|
||||
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
|
||||
preview_id: Mapped[str | None] = mapped_column(
|
||||
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
|
||||
)
|
||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True)
|
||||
)
|
||||
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||
JSON(none_as_null=True), nullable=True, default=None
|
||||
)
|
||||
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
last_access_time: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
deleted_at: Mapped[datetime | None] = mapped_column(
|
||||
DateTime(timezone=False), nullable=True, default=None
|
||||
)
|
||||
|
||||
asset: Mapped[Asset] = relationship(
|
||||
"Asset",
|
||||
back_populates="infos",
|
||||
back_populates="references",
|
||||
foreign_keys=[asset_id],
|
||||
lazy="selectin",
|
||||
)
|
||||
preview_asset: Mapped[Asset | None] = relationship(
|
||||
"Asset",
|
||||
back_populates="preview_of",
|
||||
preview_ref: Mapped[AssetReference | None] = relationship(
|
||||
"AssetReference",
|
||||
foreign_keys=[preview_id],
|
||||
remote_side=lambda: [AssetReference.id],
|
||||
)
|
||||
|
||||
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
|
||||
back_populates="asset_info",
|
||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||
back_populates="asset_reference",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
)
|
||||
|
||||
tag_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
back_populates="asset_info",
|
||||
tag_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="asset_reference",
|
||||
cascade="all,delete-orphan",
|
||||
passive_deletes=True,
|
||||
overlaps="tags,asset_infos",
|
||||
overlaps="tags,asset_references",
|
||||
)
|
||||
|
||||
tags: Mapped[list[Tag]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
back_populates="asset_infos",
|
||||
secondary="asset_reference_tags",
|
||||
back_populates="asset_references",
|
||||
lazy="selectin",
|
||||
viewonly=True,
|
||||
overlaps="tag_links,asset_info_links,asset_infos,tag",
|
||||
overlaps="tag_links,asset_reference_links,asset_references,tag",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
|
||||
Index("ix_assets_info_owner_name", "owner_id", "name"),
|
||||
Index("ix_assets_info_owner_id", "owner_id"),
|
||||
Index("ix_assets_info_asset_id", "asset_id"),
|
||||
Index("ix_assets_info_name", "name"),
|
||||
Index("ix_assets_info_created_at", "created_at"),
|
||||
Index("ix_assets_info_last_access_time", "last_access_time"),
|
||||
Index("uq_asset_references_file_path", "file_path", unique=True),
|
||||
Index("ix_asset_references_asset_id", "asset_id"),
|
||||
Index("ix_asset_references_owner_id", "owner_id"),
|
||||
Index("ix_asset_references_name", "name"),
|
||||
Index("ix_asset_references_is_missing", "is_missing"),
|
||||
Index("ix_asset_references_enrichment_level", "enrichment_level"),
|
||||
Index("ix_asset_references_created_at", "created_at"),
|
||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||
Index("ix_asset_references_preview_id", "preview_id"),
|
||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||
CheckConstraint(
|
||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||
),
|
||||
CheckConstraint(
|
||||
"enrichment_level >= 0 AND enrichment_level <= 2",
|
||||
name="ck_ar_enrichment_level_range",
|
||||
),
|
||||
)
|
||||
|
||||
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
|
||||
data = to_dict(self, include_none=include_none)
|
||||
data["tags"] = [t.name for t in self.tags]
|
||||
return data
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
|
||||
path_part = f" path={self.file_path!r}" if self.file_path else ""
|
||||
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
|
||||
|
||||
|
||||
class AssetInfoMeta(Base):
|
||||
__tablename__ = "asset_info_meta"
|
||||
class AssetReferenceMeta(Base):
|
||||
__tablename__ = "asset_reference_meta"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
asset_reference_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
key: Mapped[str] = mapped_column(String(256), primary_key=True)
|
||||
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
|
||||
@ -175,36 +182,44 @@ class AssetInfoMeta(Base):
|
||||
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
|
||||
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
|
||||
asset_reference: Mapped[AssetReference] = relationship(
|
||||
back_populates="metadata_entries"
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_meta_key", "key"),
|
||||
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
|
||||
Index("ix_asset_reference_meta_key", "key"),
|
||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||
CheckConstraint(
|
||||
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||
name="has_value",
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class AssetInfoTag(Base):
|
||||
__tablename__ = "asset_info_tags"
|
||||
class AssetReferenceTag(Base):
|
||||
__tablename__ = "asset_reference_tags"
|
||||
|
||||
asset_info_id: Mapped[str] = mapped_column(
|
||||
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
|
||||
asset_reference_id: Mapped[str] = mapped_column(
|
||||
String(36),
|
||||
ForeignKey("asset_references.id", ondelete="CASCADE"),
|
||||
primary_key=True,
|
||||
)
|
||||
tag_name: Mapped[str] = mapped_column(
|
||||
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
|
||||
)
|
||||
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
|
||||
added_at: Mapped[datetime] = mapped_column(
|
||||
DateTime(timezone=False), nullable=False, default=utcnow
|
||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||
)
|
||||
|
||||
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
|
||||
asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
|
||||
tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_asset_info_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
|
||||
Index("ix_asset_reference_tags_tag_name", "tag_name"),
|
||||
Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
|
||||
)
|
||||
|
||||
|
||||
@ -214,20 +229,18 @@ class Tag(Base):
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
|
||||
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="tag",
|
||||
overlaps="asset_infos,tags",
|
||||
overlaps="asset_references,tags",
|
||||
)
|
||||
asset_infos: Mapped[list[AssetInfo]] = relationship(
|
||||
secondary="asset_info_tags",
|
||||
asset_references: Mapped[list[AssetReference]] = relationship(
|
||||
secondary="asset_reference_tags",
|
||||
back_populates="tags",
|
||||
viewonly=True,
|
||||
overlaps="asset_info_links,tag_links,tags,asset_info",
|
||||
overlaps="asset_reference_links,tag_links,tags,asset_reference",
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index("ix_tags_tag_type", "tag_type"),
|
||||
)
|
||||
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
@ -1,976 +0,0 @@
|
||||
import os
|
||||
import logging
|
||||
import sqlalchemy as sa
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
from typing import Iterable, Any
|
||||
from sqlalchemy import select, delete, exists, func
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session, contains_eager, noload
|
||||
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
|
||||
from app.assets.helpers import (
|
||||
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
|
||||
)
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetInfo.owner_id == ""
|
||||
return AssetInfo.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
|
||||
"""
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetInfoTag.asset_info_id == AssetInfo.id)
|
||||
& (AssetInfoTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_info_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
no_row_for_key = sa.not_(
|
||||
sa.exists().where(
|
||||
AssetInfoMeta.asset_info_id == AssetInfo.id,
|
||||
AssetInfoMeta.key == key,
|
||||
)
|
||||
)
|
||||
null_row = _exists_for_pred(
|
||||
key,
|
||||
AssetInfoMeta.val_json.is_(None),
|
||||
AssetInfoMeta.val_str.is_(None),
|
||||
AssetInfoMeta.val_num.is_(None),
|
||||
AssetInfoMeta.val_bool.is_(None),
|
||||
)
|
||||
return sa.or_(no_row_for_key, null_row)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float)):
|
||||
from decimal import Decimal
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def asset_info_exists_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> bool:
|
||||
q = (
|
||||
select(sa.literal(True))
|
||||
.select_from(AssetInfo)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.limit(1)
|
||||
)
|
||||
return (session.execute(q)).first() is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
|
||||
|
||||
def get_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
) -> AssetInfo | None:
|
||||
return session.get(AssetInfo, asset_info_id)
|
||||
|
||||
|
||||
def list_asset_infos_page(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
|
||||
base = (
|
||||
select(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||
base = apply_metadata_filter(base, metadata_filter)
|
||||
|
||||
sort = (sort or "created_at").lower()
|
||||
order = (order or "desc").lower()
|
||||
sort_map = {
|
||||
"name": AssetInfo.name,
|
||||
"created_at": AssetInfo.created_at,
|
||||
"updated_at": AssetInfo.updated_at,
|
||||
"last_access_time": AssetInfo.last_access_time,
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetInfo.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
.select_from(AssetInfo)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
)
|
||||
if name_contains:
|
||||
escaped, esc = escape_like_prefix(name_contains)
|
||||
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
|
||||
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||
|
||||
total = int((session.execute(count_stmt)).scalar_one() or 0)
|
||||
|
||||
infos = (session.execute(base)).unique().scalars().all()
|
||||
|
||||
id_list: list[str] = [i.id for i in infos]
|
||||
tag_map: dict[str, list[str]] = defaultdict(list)
|
||||
if id_list:
|
||||
rows = session.execute(
|
||||
select(AssetInfoTag.asset_info_id, Tag.name)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name)
|
||||
.where(AssetInfoTag.asset_info_id.in_(id_list))
|
||||
.order_by(AssetInfoTag.added_at)
|
||||
)
|
||||
for aid, tag_name in rows.all():
|
||||
tag_map[aid].append(tag_name)
|
||||
|
||||
return infos, tag_map, total
|
||||
|
||||
|
||||
def fetch_asset_info_asset_and_tags(
|
||||
session: Session,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset, list[str]] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset, Tag.name)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
|
||||
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.order_by(Tag.name.asc())
|
||||
)
|
||||
|
||||
rows = (session.execute(stmt)).all()
|
||||
if not rows:
|
||||
return None
|
||||
|
||||
first_info, first_asset, _ = rows[0]
|
||||
tags: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for _info, _asset, tag_name in rows:
|
||||
if tag_name and tag_name not in seen:
|
||||
seen.add(tag_name)
|
||||
tags.append(tag_name)
|
||||
return first_info, first_asset, tags
|
||||
|
||||
|
||||
def fetch_asset_info_and_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[AssetInfo, Asset] | None:
|
||||
stmt = (
|
||||
select(AssetInfo, Asset)
|
||||
.join(Asset, Asset.id == AssetInfo.asset_id)
|
||||
.where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
.options(noload(AssetInfo.tags))
|
||||
)
|
||||
row = session.execute(stmt)
|
||||
pair = row.first()
|
||||
if not pair:
|
||||
return None
|
||||
return pair[0], pair[1]
|
||||
|
||||
def list_cache_states_by_asset_id(
|
||||
session: Session, *, asset_id: str
|
||||
) -> Sequence[AssetCacheState]:
|
||||
return (
|
||||
session.execute(
|
||||
select(AssetCacheState)
|
||||
.where(AssetCacheState.asset_id == asset_id)
|
||||
.order_by(AssetCacheState.id.asc())
|
||||
)
|
||||
).scalars().all()
|
||||
|
||||
|
||||
def touch_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
ts: datetime | None = None,
|
||||
only_if_newer: bool = True,
|
||||
) -> None:
|
||||
ts = ts or utcnow()
|
||||
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
|
||||
if only_if_newer:
|
||||
stmt = stmt.where(
|
||||
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
|
||||
)
|
||||
session.execute(stmt.values(last_access_time=ts))
|
||||
|
||||
|
||||
def create_asset_info_for_existing_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AssetInfo:
|
||||
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
|
||||
now = utcnow()
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"Unknown asset hash {asset_hash}")
|
||||
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
preview_id=None,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
try:
|
||||
with session.begin_nested():
|
||||
session.add(info)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
existing = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.options(noload(AssetInfo.tags))
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == name,
|
||||
AssetInfo.owner_id == owner_id,
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalars().first()
|
||||
if not existing:
|
||||
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
|
||||
return existing
|
||||
|
||||
# metadata["filename"] hack
|
||||
new_meta = dict(user_metadata or {})
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
if new_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=info.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
return info
|
||||
|
||||
|
||||
def set_asset_info_tags(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> dict:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
|
||||
).all()
|
||||
)
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all([
|
||||
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
|
||||
for t in to_add
|
||||
])
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return {"added": to_add, "removed": to_remove, "total": desired}
|
||||
|
||||
|
||||
def replace_asset_info_metadata_projection(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
user_metadata: dict | None = None,
|
||||
) -> None:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info.user_metadata = user_metadata or {}
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
|
||||
session.flush()
|
||||
|
||||
if not user_metadata:
|
||||
return
|
||||
|
||||
rows: list[AssetInfoMeta] = []
|
||||
for k, v in user_metadata.items():
|
||||
for r in project_kv(k, v):
|
||||
rows.append(
|
||||
AssetInfoMeta(
|
||||
asset_info_id=asset_info_id,
|
||||
key=r["key"],
|
||||
ordinal=int(r["ordinal"]),
|
||||
val_str=r.get("val_str"),
|
||||
val_num=r.get("val_num"),
|
||||
val_bool=r.get("val_bool"),
|
||||
val_json=r.get("val_json"),
|
||||
)
|
||||
)
|
||||
if rows:
|
||||
session.add_all(rows)
|
||||
session.flush()
|
||||
|
||||
|
||||
def ingest_fs_asset(
|
||||
session: Session,
|
||||
*,
|
||||
asset_hash: str,
|
||||
abs_path: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> dict:
|
||||
"""
|
||||
Idempotently upsert:
|
||||
- Asset by content hash (create if missing)
|
||||
- AssetCacheState(file_path) pointing to asset_id
|
||||
- Optionally AssetInfo + tag links and metadata projection
|
||||
Returns flags and ids.
|
||||
"""
|
||||
locator = os.path.abspath(abs_path)
|
||||
now = utcnow()
|
||||
|
||||
if preview_id:
|
||||
if not session.get(Asset, preview_id):
|
||||
preview_id = None
|
||||
|
||||
out: dict[str, Any] = {
|
||||
"asset_created": False,
|
||||
"asset_updated": False,
|
||||
"state_created": False,
|
||||
"state_updated": False,
|
||||
"asset_info_id": None,
|
||||
}
|
||||
|
||||
# 1) Asset by hash
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
vals = {
|
||||
"hash": asset_hash,
|
||||
"size_bytes": int(size_bytes),
|
||||
"mime_type": mime_type,
|
||||
"created_at": now,
|
||||
}
|
||||
res = session.execute(
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["asset_created"] = True
|
||||
asset = (
|
||||
session.execute(
|
||||
select(Asset).where(Asset.hash == asset_hash).limit(1)
|
||||
)
|
||||
).scalars().first()
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
else:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and asset.mime_type != mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
out["asset_updated"] = True
|
||||
|
||||
# 2) AssetCacheState upsert by file_path (unique)
|
||||
vals = {
|
||||
"asset_id": asset.id,
|
||||
"file_path": locator,
|
||||
"mtime_ns": int(mtime_ns),
|
||||
}
|
||||
ins = (
|
||||
sqlite.insert(AssetCacheState)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
|
||||
)
|
||||
|
||||
res = session.execute(ins)
|
||||
if int(res.rowcount or 0) > 0:
|
||||
out["state_created"] = True
|
||||
else:
|
||||
upd = (
|
||||
sa.update(AssetCacheState)
|
||||
.where(AssetCacheState.file_path == locator)
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetCacheState.asset_id != asset.id,
|
||||
AssetCacheState.mtime_ns.is_(None),
|
||||
AssetCacheState.mtime_ns != int(mtime_ns),
|
||||
)
|
||||
)
|
||||
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
|
||||
)
|
||||
res2 = session.execute(upd)
|
||||
if int(res2.rowcount or 0) > 0:
|
||||
out["state_updated"] = True
|
||||
|
||||
# 3) Optional AssetInfo + tags + metadata
|
||||
if info_name:
|
||||
try:
|
||||
with session.begin_nested():
|
||||
info = AssetInfo(
|
||||
owner_id=owner_id,
|
||||
name=info_name,
|
||||
asset_id=asset.id,
|
||||
preview_id=preview_id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(info)
|
||||
session.flush()
|
||||
out["asset_info_id"] = info.id
|
||||
except IntegrityError:
|
||||
pass
|
||||
|
||||
existing_info = (
|
||||
session.execute(
|
||||
select(AssetInfo)
|
||||
.where(
|
||||
AssetInfo.asset_id == asset.id,
|
||||
AssetInfo.name == info_name,
|
||||
(AssetInfo.owner_id == owner_id),
|
||||
)
|
||||
.limit(1)
|
||||
)
|
||||
).unique().scalar_one_or_none()
|
||||
if not existing_info:
|
||||
raise RuntimeError("Failed to update or insert AssetInfo.")
|
||||
|
||||
if preview_id and existing_info.preview_id != preview_id:
|
||||
existing_info.preview_id = preview_id
|
||||
|
||||
existing_info.updated_at = now
|
||||
if existing_info.last_access_time < now:
|
||||
existing_info.last_access_time = now
|
||||
session.flush()
|
||||
out["asset_info_id"] = existing_info.id
|
||||
|
||||
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
if norm and out["asset_info_id"] is not None:
|
||||
if not require_existing_tags:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
existing_tag_names = set(
|
||||
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
|
||||
)
|
||||
missing = [t for t in norm if t not in existing_tag_names]
|
||||
if missing and require_existing_tags:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
existing_links = set(
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
|
||||
)
|
||||
).all()
|
||||
)
|
||||
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
|
||||
if to_add:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=out["asset_info_id"],
|
||||
tag_name=t,
|
||||
origin=tag_origin,
|
||||
added_at=now,
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
# metadata["filename"] hack
|
||||
if out["asset_info_id"] is not None:
|
||||
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
|
||||
computed_filename = compute_relative_filename(primary_path) if primary_path else None
|
||||
|
||||
current_meta = existing_info.user_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
if user_metadata is not None:
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
replace_asset_info_metadata_projection(
|
||||
session,
|
||||
asset_info_id=out["asset_info_id"],
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
return out
|
||||
|
||||
|
||||
def update_asset_info_full(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
tag_origin: str = "manual",
|
||||
asset_info_row: Any = None,
|
||||
) -> AssetInfo:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
else:
|
||||
info = asset_info_row
|
||||
|
||||
touched = False
|
||||
if name is not None and name != info.name:
|
||||
info.name = name
|
||||
touched = True
|
||||
|
||||
computed_filename = None
|
||||
try:
|
||||
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
|
||||
if p:
|
||||
computed_filename = compute_relative_filename(p)
|
||||
except Exception:
|
||||
computed_filename = None
|
||||
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
else:
|
||||
if computed_filename:
|
||||
current_meta = info.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
new_meta["filename"] = computed_filename
|
||||
replace_asset_info_metadata_projection(
|
||||
session, asset_info_id=asset_info_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_asset_info_tags(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
|
||||
return info
|
||||
|
||||
|
||||
def delete_asset_info_by_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str,
|
||||
) -> bool:
|
||||
stmt = sa.delete(AssetInfo).where(
|
||||
AssetInfo.id == asset_info_id,
|
||||
visible_owner_clause(owner_id),
|
||||
)
|
||||
return int((session.execute(stmt)).rowcount or 0) > 0
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetInfoTag.tag_name.label("tag_name"),
|
||||
func.count(AssetInfoTag.asset_info_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetInfoTag)
|
||||
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
|
||||
.where(visible_owner_clause(owner_id))
|
||||
.group_by(AssetInfoTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_like_prefix(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
total_q = total_q.where(
|
||||
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
|
||||
)
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def add_tags_to_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
asset_info_row: Any = None,
|
||||
) -> dict:
|
||||
if not asset_info_row:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"added": [], "already_present": [], "total_tags": total}
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetInfoTag(
|
||||
asset_info_id=asset_info_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=utcnow(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
|
||||
return {
|
||||
"added": sorted(((after - current) & want)),
|
||||
"already_present": sorted(want & current),
|
||||
"total_tags": sorted(after),
|
||||
}
|
||||
|
||||
|
||||
def remove_tags_from_asset_info(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> dict:
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": [], "not_present": [], "total_tags": total}
|
||||
|
||||
existing = {
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
|
||||
)
|
||||
).all()
|
||||
}
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetInfoTag)
|
||||
.where(
|
||||
AssetInfoTag.asset_info_id == asset_info_id,
|
||||
AssetInfoTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def set_asset_info_preview(
|
||||
session: Session,
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
) -> None:
|
||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||
info = session.get(AssetInfo, asset_info_id)
|
||||
if not info:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
if preview_asset_id is None:
|
||||
info.preview_id = None
|
||||
else:
|
||||
# validate preview asset exists
|
||||
if not session.get(Asset, preview_asset_id):
|
||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
||||
info.preview_id = preview_asset_id
|
||||
|
||||
info.updated_at = utcnow()
|
||||
session.flush()
|
||||
133
app/assets/database/queries/__init__.py
Normal file
133
app/assets/database/queries/__init__.py
Normal file
@ -0,0 +1,133 @@
|
||||
from app.assets.database.queries.asset import (
|
||||
asset_exists_by_hash,
|
||||
bulk_insert_assets,
|
||||
get_asset_by_hash,
|
||||
get_existing_asset_ids,
|
||||
reassign_asset_references,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
)
|
||||
from app.assets.database.queries.asset_reference import (
|
||||
CacheStateRow,
|
||||
UnenrichedReferenceRow,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_update_enrichment_level,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
convert_metadata_to_rows,
|
||||
delete_assets_by_ids,
|
||||
delete_orphaned_seed_asset,
|
||||
delete_reference_by_id,
|
||||
delete_references_by_ids,
|
||||
fetch_reference_and_asset,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_or_create_reference,
|
||||
get_reference_by_file_path,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
insert_reference,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
list_references_page,
|
||||
mark_references_missing_outside_prefixes,
|
||||
rebuild_metadata_projection,
|
||||
reference_exists,
|
||||
reference_exists_for_asset_id,
|
||||
restore_references_by_paths,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_system_metadata,
|
||||
soft_delete_reference_by_id,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_is_missing_by_asset_id,
|
||||
update_reference_timestamps,
|
||||
update_reference_updated_at,
|
||||
upsert_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
SetTagsResult,
|
||||
add_missing_tag_for_asset_id,
|
||||
add_tags_to_reference,
|
||||
bulk_insert_tags_and_meta,
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
list_tag_counts_for_filtered_assets,
|
||||
list_tags_with_usage,
|
||||
remove_missing_tag_for_asset_id,
|
||||
remove_tags_from_reference,
|
||||
set_reference_tags,
|
||||
validate_tags_exist,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"CacheStateRow",
|
||||
"RemoveTagsResult",
|
||||
"SetTagsResult",
|
||||
"UnenrichedReferenceRow",
|
||||
"add_missing_tag_for_asset_id",
|
||||
"add_tags_to_reference",
|
||||
"asset_exists_by_hash",
|
||||
"bulk_insert_assets",
|
||||
"bulk_insert_references_ignore_conflicts",
|
||||
"bulk_insert_tags_and_meta",
|
||||
"bulk_update_enrichment_level",
|
||||
"bulk_update_is_missing",
|
||||
"bulk_update_needs_verify",
|
||||
"convert_metadata_to_rows",
|
||||
"delete_assets_by_ids",
|
||||
"delete_orphaned_seed_asset",
|
||||
"delete_reference_by_id",
|
||||
"delete_references_by_ids",
|
||||
"ensure_tags_exist",
|
||||
"fetch_reference_and_asset",
|
||||
"fetch_reference_asset_and_tags",
|
||||
"get_asset_by_hash",
|
||||
"get_existing_asset_ids",
|
||||
"get_or_create_reference",
|
||||
"get_reference_by_file_path",
|
||||
"get_reference_by_id",
|
||||
"get_reference_with_owner_check",
|
||||
"get_reference_ids_by_ids",
|
||||
"get_reference_tags",
|
||||
"get_references_by_paths_and_asset_ids",
|
||||
"get_references_for_prefixes",
|
||||
"get_unenriched_references",
|
||||
"get_unreferenced_unhashed_asset_ids",
|
||||
"insert_reference",
|
||||
"list_all_file_paths_by_asset_id",
|
||||
"list_references_by_asset_id",
|
||||
"list_references_page",
|
||||
"list_tag_counts_for_filtered_assets",
|
||||
"list_tags_with_usage",
|
||||
"mark_references_missing_outside_prefixes",
|
||||
"reassign_asset_references",
|
||||
"rebuild_metadata_projection",
|
||||
"reference_exists",
|
||||
"reference_exists_for_asset_id",
|
||||
"remove_missing_tag_for_asset_id",
|
||||
"remove_tags_from_reference",
|
||||
"restore_references_by_paths",
|
||||
"set_reference_metadata",
|
||||
"set_reference_preview",
|
||||
"set_reference_system_metadata",
|
||||
"soft_delete_reference_by_id",
|
||||
"set_reference_tags",
|
||||
"update_asset_hash_and_mime",
|
||||
"update_is_missing_by_asset_id",
|
||||
"update_reference_access_time",
|
||||
"update_reference_name",
|
||||
"update_reference_timestamps",
|
||||
"update_reference_updated_at",
|
||||
"upsert_asset",
|
||||
"upsert_reference",
|
||||
"validate_tags_exist",
|
||||
]
|
||||
140
app/assets/database/queries/asset.py
Normal file
140
app/assets/database/queries/asset.py
Normal file
@ -0,0 +1,140 @@
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
|
||||
|
||||
|
||||
def asset_exists_by_hash(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
row = (
|
||||
session.execute(
|
||||
select(sa.literal(True))
|
||||
.select_from(Asset)
|
||||
.where(Asset.hash == asset_hash)
|
||||
.limit(1)
|
||||
)
|
||||
).first()
|
||||
return row is not None
|
||||
|
||||
|
||||
def get_asset_by_hash(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
) -> Asset | None:
|
||||
return (
|
||||
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
|
||||
|
||||
def upsert_asset(
|
||||
session: Session,
|
||||
asset_hash: str,
|
||||
size_bytes: int,
|
||||
mime_type: str | None = None,
|
||||
) -> tuple[Asset, bool, bool]:
|
||||
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
|
||||
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
|
||||
if mime_type:
|
||||
vals["mime_type"] = mime_type
|
||||
|
||||
ins = (
|
||||
sqlite.insert(Asset)
|
||||
.values(**vals)
|
||||
.on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
)
|
||||
res = session.execute(ins)
|
||||
created = int(res.rowcount or 0) > 0
|
||||
|
||||
asset = (
|
||||
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
|
||||
.scalars()
|
||||
.first()
|
||||
)
|
||||
if not asset:
|
||||
raise RuntimeError("Asset row not found after upsert.")
|
||||
|
||||
updated = False
|
||||
if not created:
|
||||
changed = False
|
||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||
asset.size_bytes = int(size_bytes)
|
||||
changed = True
|
||||
if mime_type and not asset.mime_type:
|
||||
asset.mime_type = mime_type
|
||||
changed = True
|
||||
if changed:
|
||||
updated = True
|
||||
|
||||
return asset, created, updated
|
||||
|
||||
|
||||
def bulk_insert_assets(
|
||||
session: Session,
|
||||
rows: list[dict],
|
||||
) -> None:
|
||||
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
|
||||
if not rows:
|
||||
return
|
||||
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
|
||||
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
|
||||
session.execute(ins, chunk)
|
||||
|
||||
|
||||
def get_existing_asset_ids(
|
||||
session: Session,
|
||||
asset_ids: list[str],
|
||||
) -> set[str]:
|
||||
"""Return the subset of asset_ids that exist in the database."""
|
||||
if not asset_ids:
|
||||
return set()
|
||||
found: set[str] = set()
|
||||
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
|
||||
rows = session.execute(
|
||||
select(Asset.id).where(Asset.id.in_(chunk))
|
||||
).fetchall()
|
||||
found.update(row[0] for row in rows)
|
||||
return found
|
||||
|
||||
|
||||
def update_asset_hash_and_mime(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
asset_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
) -> bool:
|
||||
"""Update asset hash and/or mime_type. Returns True if asset was found."""
|
||||
asset = session.get(Asset, asset_id)
|
||||
if not asset:
|
||||
return False
|
||||
if asset_hash is not None:
|
||||
asset.hash = asset_hash
|
||||
if mime_type is not None and not asset.mime_type:
|
||||
asset.mime_type = mime_type
|
||||
return True
|
||||
|
||||
|
||||
def reassign_asset_references(
|
||||
session: Session,
|
||||
from_asset_id: str,
|
||||
to_asset_id: str,
|
||||
reference_id: str,
|
||||
) -> None:
|
||||
"""Reassign a reference from one asset to another.
|
||||
|
||||
Used when merging a stub asset into an existing asset with the same hash.
|
||||
"""
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if ref and ref.asset_id == from_asset_id:
|
||||
ref.asset_id = to_asset_id
|
||||
|
||||
session.flush()
|
||||
1028
app/assets/database/queries/asset_reference.py
Normal file
1028
app/assets/database/queries/asset_reference.py
Normal file
File diff suppressed because it is too large
Load Diff
127
app/assets/database/queries/common.py
Normal file
127
app/assets/database/queries/common.py
Normal file
@ -0,0 +1,127 @@
|
||||
"""Shared utilities for database query modules."""
|
||||
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import exists
|
||||
|
||||
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
||||
from app.assets.helpers import escape_sql_like_string, normalize_tags
|
||||
|
||||
MAX_BIND_PARAMS = 800
|
||||
|
||||
|
||||
def calculate_rows_per_statement(cols: int) -> int:
|
||||
"""Calculate how many rows can fit in one statement given column count."""
|
||||
return max(1, MAX_BIND_PARAMS // max(1, cols))
|
||||
|
||||
|
||||
def iter_chunks(seq, n: int):
|
||||
"""Yield successive n-sized chunks from seq."""
|
||||
for i in range(0, len(seq), n):
|
||||
yield seq[i : i + n]
|
||||
|
||||
|
||||
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
|
||||
"""Yield chunks of rows sized to fit within bind param limits."""
|
||||
if not rows:
|
||||
return
|
||||
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
|
||||
|
||||
|
||||
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
|
||||
"""Build owner visibility predicate for reads.
|
||||
|
||||
Owner-less rows are visible to everyone.
|
||||
"""
|
||||
owner_id = (owner_id or "").strip()
|
||||
if owner_id == "":
|
||||
return AssetReference.owner_id == ""
|
||||
return AssetReference.owner_id.in_(["", owner_id])
|
||||
|
||||
|
||||
def build_prefix_like_conditions(
|
||||
prefixes: list[str],
|
||||
) -> list[sa.sql.ColumnElement]:
|
||||
"""Build LIKE conditions for matching file paths under directory prefixes."""
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_sql_like_string(base)
|
||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||
return conds
|
||||
|
||||
|
||||
def apply_tag_filters(
|
||||
stmt: sa.sql.Select,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||
include_tags = normalize_tags(include_tags)
|
||||
exclude_tags = normalize_tags(exclude_tags)
|
||||
|
||||
if include_tags:
|
||||
for tag_name in include_tags:
|
||||
stmt = stmt.where(
|
||||
exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == tag_name)
|
||||
)
|
||||
)
|
||||
|
||||
if exclude_tags:
|
||||
stmt = stmt.where(
|
||||
~exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||
)
|
||||
)
|
||||
return stmt
|
||||
|
||||
|
||||
def apply_metadata_filter(
|
||||
stmt: sa.sql.Select,
|
||||
metadata_filter: dict | None = None,
|
||||
) -> sa.sql.Select:
|
||||
"""Apply filters using asset_reference_meta projection table."""
|
||||
if not metadata_filter:
|
||||
return stmt
|
||||
|
||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||
return sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
*preds,
|
||||
)
|
||||
|
||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||
if value is None:
|
||||
return sa.not_(
|
||||
sa.exists().where(
|
||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||
AssetReferenceMeta.key == key,
|
||||
)
|
||||
)
|
||||
|
||||
if isinstance(value, bool):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||
if isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||
if isinstance(value, str):
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||
|
||||
for k, v in metadata_filter.items():
|
||||
if isinstance(v, list):
|
||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||
if ors:
|
||||
stmt = stmt.where(sa.or_(*ors))
|
||||
else:
|
||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||
return stmt
|
||||
418
app/assets/database/queries/tags.py
Normal file
418
app/assets/database/queries/tags.py
Normal file
@ -0,0 +1,418 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Iterable, Sequence
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy import delete, func, select
|
||||
from sqlalchemy.dialects import sqlite
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import (
|
||||
Asset,
|
||||
AssetReference,
|
||||
AssetReferenceMeta,
|
||||
AssetReferenceTag,
|
||||
Tag,
|
||||
)
|
||||
from app.assets.database.queries.common import (
|
||||
apply_metadata_filter,
|
||||
apply_tag_filters,
|
||||
build_visible_owner_clause,
|
||||
iter_row_chunks,
|
||||
)
|
||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AddTagsResult:
|
||||
added: list[str]
|
||||
already_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RemoveTagsResult:
|
||||
removed: list[str]
|
||||
not_present: list[str]
|
||||
total_tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class SetTagsResult:
|
||||
added: list[str]
|
||||
removed: list[str]
|
||||
total: list[str]
|
||||
|
||||
|
||||
def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
"""Raise ValueError if any of the given tag names do not exist."""
|
||||
existing_tag_names = set(
|
||||
name
|
||||
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
|
||||
)
|
||||
missing = [t for t in tags if t not in existing_tag_names]
|
||||
if missing:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
session.execute(ins)
|
||||
|
||||
|
||||
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
||||
return [
|
||||
tag_name
|
||||
for (tag_name,) in (
|
||||
session.execute(
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||
.order_by(AssetReferenceTag.tag_name.asc())
|
||||
)
|
||||
).all()
|
||||
]
|
||||
|
||||
|
||||
def set_reference_tags(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
) -> SetTagsResult:
|
||||
desired = normalize_tags(tags)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
to_add = [t for t in desired if t not in current]
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id,
|
||||
AssetReferenceTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
|
||||
|
||||
|
||||
def add_tags_to_reference(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
origin: str = "manual",
|
||||
create_if_missing: bool = True,
|
||||
reference_row: AssetReference | None = None,
|
||||
) -> AddTagsResult:
|
||||
if not reference_row:
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
want = set(norm)
|
||||
to_add = sorted(want - current)
|
||||
|
||||
if to_add:
|
||||
with session.begin_nested() as nested:
|
||||
try:
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
asset_reference_id=reference_id,
|
||||
tag_name=t,
|
||||
origin=origin,
|
||||
added_at=get_utc_now(),
|
||||
)
|
||||
for t in to_add
|
||||
]
|
||||
)
|
||||
session.flush()
|
||||
except IntegrityError:
|
||||
nested.rollback()
|
||||
|
||||
after = set(get_reference_tags(session, reference_id=reference_id))
|
||||
return AddTagsResult(
|
||||
added=sorted(((after - current) & want)),
|
||||
already_present=sorted(want & current),
|
||||
total_tags=sorted(after),
|
||||
)
|
||||
|
||||
|
||||
def remove_tags_from_reference(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
tags: Sequence[str],
|
||||
) -> RemoveTagsResult:
|
||||
ref = session.get(AssetReference, reference_id)
|
||||
if not ref:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
norm = normalize_tags(tags)
|
||||
if not norm:
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
|
||||
|
||||
existing = set(get_reference_tags(session, reference_id))
|
||||
|
||||
to_remove = sorted(set(t for t in norm if t in existing))
|
||||
not_present = sorted(set(t for t in norm if t not in existing))
|
||||
|
||||
if to_remove:
|
||||
session.execute(
|
||||
delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id == reference_id,
|
||||
AssetReferenceTag.tag_name.in_(to_remove),
|
||||
)
|
||||
)
|
||||
session.flush()
|
||||
|
||||
total = get_reference_tags(session, reference_id=reference_id)
|
||||
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
|
||||
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sa.select(
|
||||
AssetReference.id.label("asset_reference_id"),
|
||||
sa.literal("missing").label("tag_name"),
|
||||
sa.literal(origin).label("origin"),
|
||||
sa.literal(get_utc_now()).label("added_at"),
|
||||
)
|
||||
.where(AssetReference.asset_id == asset_id)
|
||||
.where(
|
||||
sa.not_(
|
||||
sa.exists().where(
|
||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||
& (AssetReferenceTag.tag_name == "missing")
|
||||
)
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetReferenceTag)
|
||||
.from_select(
|
||||
["asset_reference_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceTag.asset_reference_id,
|
||||
AssetReferenceTag.tag_name,
|
||||
]
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sa.delete(AssetReferenceTag).where(
|
||||
AssetReferenceTag.asset_reference_id.in_(
|
||||
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
|
||||
),
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def list_tags_with_usage(
|
||||
session: Session,
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
include_zero: bool = True,
|
||||
order: str = "count_desc",
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[tuple[str, str, int]], int]:
|
||||
counts_sq = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name.label("tag_name"),
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.select_from(AssetReferenceTag)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
|
||||
)
|
||||
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
q = q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
|
||||
if not include_zero:
|
||||
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
|
||||
|
||||
if order == "name_asc":
|
||||
q = q.order_by(Tag.name.asc())
|
||||
else:
|
||||
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
|
||||
|
||||
total_q = select(func.count()).select_from(Tag)
|
||||
if prefix:
|
||||
escaped, esc = escape_sql_like_string(prefix.strip().lower())
|
||||
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
|
||||
if not include_zero:
|
||||
visible_tags_sq = (
|
||||
select(AssetReferenceTag.tag_name)
|
||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(
|
||||
sa.or_(
|
||||
AssetReference.is_missing == False, # noqa: E712
|
||||
AssetReferenceTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
)
|
||||
total_q = total_q.where(Tag.name.in_(visible_tags_sq))
|
||||
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
def list_tag_counts_for_filtered_assets(
|
||||
session: Session,
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
"""Return tag counts for assets matching the given filters.
|
||||
|
||||
Uses the same filtering logic as list_references_page but returns
|
||||
{tag_name: count} instead of paginated references.
|
||||
"""
|
||||
# Build a subquery of matching reference IDs
|
||||
ref_sq = (
|
||||
select(AssetReference.id)
|
||||
.join(Asset, Asset.id == AssetReference.asset_id)
|
||||
.where(build_visible_owner_clause(owner_id))
|
||||
.where(AssetReference.is_missing == False) # noqa: E712
|
||||
.where(AssetReference.deleted_at.is_(None))
|
||||
)
|
||||
|
||||
if name_contains:
|
||||
escaped, esc = escape_sql_like_string(name_contains)
|
||||
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||
|
||||
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
|
||||
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
|
||||
ref_sq = ref_sq.subquery()
|
||||
|
||||
# Count tags across those references
|
||||
q = (
|
||||
select(
|
||||
AssetReferenceTag.tag_name,
|
||||
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||
)
|
||||
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
|
||||
.group_by(AssetReferenceTag.tag_name)
|
||||
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
|
||||
.limit(limit)
|
||||
)
|
||||
|
||||
rows = session.execute(q).all()
|
||||
return {tag_name: int(cnt) for tag_name, cnt in rows}
|
||||
|
||||
|
||||
def bulk_insert_tags_and_meta(
|
||||
session: Session,
|
||||
tag_rows: list[dict],
|
||||
meta_rows: list[dict],
|
||||
) -> None:
|
||||
"""Batch insert into asset_reference_tags and asset_reference_meta.
|
||||
|
||||
Uses ON CONFLICT DO NOTHING.
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
|
||||
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
|
||||
"""
|
||||
if tag_rows:
|
||||
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceTag.asset_reference_id,
|
||||
AssetReferenceTag.tag_name,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
|
||||
session.execute(ins_tags, chunk)
|
||||
|
||||
if meta_rows:
|
||||
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
|
||||
index_elements=[
|
||||
AssetReferenceMeta.asset_reference_id,
|
||||
AssetReferenceMeta.key,
|
||||
AssetReferenceMeta.ordinal,
|
||||
]
|
||||
)
|
||||
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
|
||||
session.execute(ins_meta, chunk)
|
||||
@ -1,62 +0,0 @@
|
||||
from typing import Iterable
|
||||
|
||||
import sqlalchemy
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.dialects import sqlite
|
||||
|
||||
from app.assets.helpers import normalize_tags, utcnow
|
||||
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
|
||||
|
||||
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
.on_conflict_do_nothing(index_elements=[Tag.name])
|
||||
)
|
||||
return session.execute(ins)
|
||||
|
||||
def add_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
origin: str = "automatic",
|
||||
) -> None:
|
||||
select_rows = (
|
||||
sqlalchemy.select(
|
||||
AssetInfo.id.label("asset_info_id"),
|
||||
sqlalchemy.literal("missing").label("tag_name"),
|
||||
sqlalchemy.literal(origin).label("origin"),
|
||||
sqlalchemy.literal(utcnow()).label("added_at"),
|
||||
)
|
||||
.where(AssetInfo.asset_id == asset_id)
|
||||
.where(
|
||||
sqlalchemy.not_(
|
||||
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
|
||||
)
|
||||
)
|
||||
)
|
||||
session.execute(
|
||||
sqlite.insert(AssetInfoTag)
|
||||
.from_select(
|
||||
["asset_info_id", "tag_name", "origin", "added_at"],
|
||||
select_rows,
|
||||
)
|
||||
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
|
||||
)
|
||||
|
||||
def remove_missing_tag_for_asset_id(
|
||||
session: Session,
|
||||
*,
|
||||
asset_id: str,
|
||||
) -> None:
|
||||
session.execute(
|
||||
sqlalchemy.delete(AssetInfoTag).where(
|
||||
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
|
||||
AssetInfoTag.tag_name == "missing",
|
||||
)
|
||||
)
|
||||
@ -1,75 +0,0 @@
|
||||
from blake3 import blake3
|
||||
from typing import IO
|
||||
import os
|
||||
import asyncio
|
||||
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
|
||||
|
||||
# NOTE: this allows hashing different representations of a file-like object
|
||||
def blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""
|
||||
Returns a BLAKE3 hex digest for ``fp``, which may be:
|
||||
- a filename (str/bytes) or PathLike
|
||||
- an open binary file object
|
||||
If ``fp`` is a file object, it must be opened in **binary** mode and support
|
||||
``read``, ``seek``, and ``tell``. The function will seek to the start before
|
||||
reading and will attempt to restore the original position afterward.
|
||||
"""
|
||||
# duck typing to check if input is a file-like object
|
||||
if hasattr(fp, "read"):
|
||||
return _hash_file_obj(fp, chunk_size)
|
||||
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
|
||||
async def blake3_hash_async(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
) -> str:
|
||||
"""Async wrapper for ``blake3_hash_sync``.
|
||||
Uses a worker thread so the event loop remains responsive.
|
||||
"""
|
||||
# If it is a path, open inside the worker thread to keep I/O off the loop.
|
||||
if hasattr(fp, "read"):
|
||||
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
|
||||
|
||||
def _worker() -> str:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
return _hash_file_obj(f, chunk_size)
|
||||
|
||||
return await asyncio.to_thread(_worker)
|
||||
|
||||
|
||||
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
|
||||
"""
|
||||
Hash an already-open binary file object by streaming in chunks.
|
||||
- Seeks to the beginning before reading (if supported).
|
||||
- Restores the original position afterward (if tell/seek are supported).
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
# in case file object is already open and not at the beginning, track so can be restored after hashing
|
||||
orig_pos = file_obj.tell()
|
||||
|
||||
try:
|
||||
# seek to the beginning before reading
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(0)
|
||||
|
||||
h = blake3()
|
||||
while True:
|
||||
chunk = file_obj.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
return h.hexdigest()
|
||||
finally:
|
||||
# restore original position in file object, if needed
|
||||
if orig_pos != 0:
|
||||
file_obj.seek(orig_pos)
|
||||
@ -1,226 +1,42 @@
|
||||
import contextlib
|
||||
import os
|
||||
from decimal import Decimal
|
||||
from aiohttp import web
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Literal, Any
|
||||
|
||||
import folder_paths
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
|
||||
|
||||
def get_query_dict(request: web.Request) -> dict[str, Any]:
|
||||
def select_best_live_path(states: Sequence) -> str:
|
||||
"""
|
||||
Gets a dictionary of query parameters from the request.
|
||||
|
||||
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
|
||||
Return the best on-disk path among cache states:
|
||||
1) Prefer a path that exists with needs_verify == False (already verified).
|
||||
2) Otherwise, pick the first path that exists.
|
||||
3) Otherwise return empty string.
|
||||
"""
|
||||
query_dict = {
|
||||
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
|
||||
for key in request.query.keys()
|
||||
}
|
||||
return query_dict
|
||||
alive = [
|
||||
s
|
||||
for s in states
|
||||
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
|
||||
]
|
||||
if not alive:
|
||||
return ""
|
||||
for s in alive:
|
||||
if not getattr(s, "needs_verify", False):
|
||||
return s.file_path
|
||||
return alive[0].file_path
|
||||
|
||||
def list_tree(base_dir: str) -> list[str]:
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
|
||||
for name in filenames:
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
|
||||
def prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char in a LIKE prefix.
|
||||
|
||||
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
|
||||
"""Escapes %, _ and the escape char itself in a LIKE prefix.
|
||||
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
|
||||
Returns (escaped_prefix, escape_char).
|
||||
"""
|
||||
s = s.replace(escape, escape + escape) # escape the escape char first
|
||||
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
|
||||
return s, escape
|
||||
|
||||
def fast_asset_file_check(
|
||||
*,
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
sz = int(size_db or 0)
|
||||
if sz > 0:
|
||||
return int(stat_result.st_size) == sz
|
||||
return True
|
||||
|
||||
def utcnow() -> datetime:
|
||||
def get_utc_now() -> datetime:
|
||||
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
|
||||
return datetime.now(timezone.utc).replace(tzinfo=None)
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
|
||||
|
||||
We trust `folder_paths.folder_names_and_paths` and include a category if
|
||||
*any* of its base paths lies under the Comfy `models_dir`.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
models_root = os.path.abspath(folder_paths.models_dir)
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
|
||||
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
root = tags[0]
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
else:
|
||||
base_dir = os.path.abspath(
|
||||
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
|
||||
)
|
||||
raw_subdirs = tags[1:]
|
||||
for i in raw_subdirs:
|
||||
if i in (".", ".."):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
def ensure_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = os.path.abspath(candidate)
|
||||
base_abs = os.path.abspath(base)
|
||||
try:
|
||||
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
|
||||
raise ValueError("destination escapes base directory")
|
||||
except Exception:
|
||||
raise ValueError("invalid destination path")
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Given an absolute or relative file path, determine which root category the path belongs to:
|
||||
- 'input' if the file resides under `folder_paths.get_input_directory()`
|
||||
- 'output' if the file resides under `folder_paths.get_output_directory()`
|
||||
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
For 'models', the relative path is prefixed with the category name:
|
||||
e.g. ('models', 'vae/test/sub/ae.safetensors')
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _is_within(child: str, parent: str) -> bool:
|
||||
try:
|
||||
return os.path.commonpath([child, parent]) == parent
|
||||
except Exception:
|
||||
return False
|
||||
|
||||
def _rel(child: str, parent: str) -> str:
|
||||
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _is_within(fp_abs, input_base):
|
||||
return "input", _rel(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _is_within(fp_abs, output_base):
|
||||
return "output", _rel(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return a tuple (name, tags) derived from a filesystem path.
|
||||
|
||||
Semantics:
|
||||
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
|
||||
- The returned `name` is the base filename with extension from the relative path.
|
||||
- The returned `tags` are:
|
||||
[root_category] + parent folders of the relative path (in order)
|
||||
For 'models', this means:
|
||||
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
|
||||
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
|
||||
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
|
||||
|
||||
Raises:
|
||||
ValueError: if the path does not belong to input, output, or configured model bases.
|
||||
"""
|
||||
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
|
||||
def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
"""
|
||||
@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
|
||||
- Stripping whitespace and converting to lowercase.
|
||||
- Removing duplicates.
|
||||
"""
|
||||
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
|
||||
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
with contextlib.suppress(Exception):
|
||||
if os.path.commonpath([abs_path, base_abs]) == base_abs:
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
def is_scalar(v):
|
||||
if v is None:
|
||||
return True
|
||||
if isinstance(v, bool):
|
||||
return True
|
||||
if isinstance(v, (int, float, Decimal, str)):
|
||||
return True
|
||||
return False
|
||||
def validate_blake3_hash(s: str) -> str:
|
||||
"""Validate and normalize a blake3 hash string.
|
||||
|
||||
def project_kv(key: str, value):
|
||||
Returns canonical 'blake3:<hex>' or raises ValueError.
|
||||
"""
|
||||
Turn a metadata key/value into typed projection rows.
|
||||
Returns list[dict] with keys:
|
||||
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
|
||||
"""
|
||||
rows: list[dict] = []
|
||||
|
||||
def _null_row(ordinal: int) -> dict:
|
||||
return {
|
||||
"key": key, "ordinal": ordinal,
|
||||
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
|
||||
}
|
||||
|
||||
if value is None:
|
||||
rows.append(_null_row(0))
|
||||
return rows
|
||||
|
||||
if is_scalar(value):
|
||||
if isinstance(value, bool):
|
||||
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
|
||||
elif isinstance(value, (int, float, Decimal)):
|
||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||
rows.append({"key": key, "ordinal": 0, "val_num": num})
|
||||
elif isinstance(value, str):
|
||||
rows.append({"key": key, "ordinal": 0, "val_str": value})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
|
||||
if isinstance(value, list):
|
||||
if all(is_scalar(x) for x in value):
|
||||
for i, x in enumerate(value):
|
||||
if x is None:
|
||||
rows.append(_null_row(i))
|
||||
elif isinstance(x, bool):
|
||||
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
|
||||
elif isinstance(x, (int, float, Decimal)):
|
||||
num = x if isinstance(x, Decimal) else Decimal(str(x))
|
||||
rows.append({"key": key, "ordinal": i, "val_num": num})
|
||||
elif isinstance(x, str):
|
||||
rows.append({"key": key, "ordinal": i, "val_str": x})
|
||||
else:
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
for i, x in enumerate(value):
|
||||
rows.append({"key": key, "ordinal": i, "val_json": x})
|
||||
return rows
|
||||
|
||||
rows.append({"key": key, "ordinal": 0, "val_json": value})
|
||||
return rows
|
||||
s = s.strip().lower()
|
||||
if not s or ":" not in s:
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
algo, digest = s.split(":", 1)
|
||||
if (
|
||||
algo != "blake3"
|
||||
or len(digest) != 64
|
||||
or any(c for c in digest if c not in "0123456789abcdef")
|
||||
):
|
||||
raise ValueError("hash must be 'blake3:<hex>'")
|
||||
return f"{algo}:{digest}"
|
||||
|
||||
@ -1,516 +0,0 @@
|
||||
import os
|
||||
import mimetypes
|
||||
import contextlib
|
||||
from typing import Sequence
|
||||
|
||||
from app.database.db import create_session
|
||||
from app.assets.api import schemas_out, schemas_in
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
asset_info_exists_for_asset_id,
|
||||
get_asset_by_hash,
|
||||
get_asset_info_by_id,
|
||||
fetch_asset_info_asset_and_tags,
|
||||
fetch_asset_info_and_asset,
|
||||
create_asset_info_for_existing_asset,
|
||||
touch_asset_info_by_id,
|
||||
update_asset_info_full,
|
||||
delete_asset_info_by_id,
|
||||
list_cache_states_by_asset_id,
|
||||
list_asset_infos_page,
|
||||
list_tags_with_usage,
|
||||
get_asset_tags,
|
||||
add_tags_to_asset_info,
|
||||
remove_tags_from_asset_info,
|
||||
pick_best_live_path,
|
||||
ingest_fs_asset,
|
||||
set_asset_info_preview,
|
||||
)
|
||||
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
|
||||
from app.assets.database.models import Asset
|
||||
|
||||
|
||||
def _safe_sort_field(requested: str | None) -> str:
|
||||
if not requested:
|
||||
return "created_at"
|
||||
v = requested.lower()
|
||||
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
|
||||
return v
|
||||
return "created_at"
|
||||
|
||||
|
||||
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
|
||||
st = os.stat(path, follow_symlinks=True)
|
||||
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
|
||||
|
||||
|
||||
def _safe_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
if n:
|
||||
return n
|
||||
return fallback
|
||||
|
||||
|
||||
def asset_exists(*, asset_hash: str) -> bool:
|
||||
"""
|
||||
Check if an asset with a given hash exists in database.
|
||||
"""
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def list_assets(
|
||||
*,
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetsList:
|
||||
sort = _safe_sort_field(sort)
|
||||
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
|
||||
|
||||
with create_session() as session:
|
||||
infos, tag_map, total = list_asset_infos_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
summaries: list[schemas_out.AssetSummary] = []
|
||||
for info in infos:
|
||||
asset = info.asset
|
||||
tags = tag_map.get(info.id, [])
|
||||
summaries.append(
|
||||
schemas_out.AssetSummary(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
created_at=info.created_at,
|
||||
updated_at=info.updated_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
)
|
||||
|
||||
return schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=total,
|
||||
has_more=(offset + len(summaries)) < total,
|
||||
)
|
||||
|
||||
|
||||
def get_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
info, asset, tag_names = res
|
||||
preview_id = info.preview_id
|
||||
|
||||
return schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_content_for_download(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
owner_id: str = "",
|
||||
) -> tuple[str, str, str]:
|
||||
with create_session() as session:
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
|
||||
info, asset = pair
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = pick_best_live_path(states)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError
|
||||
|
||||
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
session.commit()
|
||||
|
||||
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
|
||||
download_name = info.name or os.path.basename(abs_path)
|
||||
return abs_path, ctype, download_name
|
||||
|
||||
|
||||
def upload_asset_from_temp_path(
|
||||
spec: schemas_in.UploadAssetSpec,
|
||||
*,
|
||||
temp_path: str,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_asset_hash: str | None = None,
|
||||
) -> schemas_out.AssetCreated:
|
||||
"""
|
||||
Create new asset or update existing asset from a temporary file path.
|
||||
"""
|
||||
try:
|
||||
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
|
||||
import app.assets.hashing as hashing
|
||||
digest = hashing.blake3_hash(temp_path)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
|
||||
raise ValueError("HASH_MISMATCH")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
session.commit()
|
||||
|
||||
return schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=existing.hash,
|
||||
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
|
||||
mime_type=existing.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or spec.name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
ensure_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
with create_session() as session:
|
||||
result = ingest_fs_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=None,
|
||||
user_metadata=spec.user_metadata or {},
|
||||
tags=spec.tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
info_id = result["asset_info_id"]
|
||||
if not info_id:
|
||||
raise RuntimeError("failed to create asset metadata")
|
||||
|
||||
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
info, asset = pair
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
created_result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=result["asset_created"],
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return created_result
|
||||
|
||||
|
||||
def update_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetUpdated:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
info = update_asset_info_full(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
name=name,
|
||||
tags=tags,
|
||||
user_metadata=user_metadata,
|
||||
tag_origin="manual",
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
|
||||
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
|
||||
result = schemas_out.AssetUpdated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=info.asset.hash if info.asset else None,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
updated_at=info.updated_at,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
preview_asset_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetDetail:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
set_asset_info_preview(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
preview_asset_id=preview_asset_id,
|
||||
)
|
||||
|
||||
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not res:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
info, asset, tags = res
|
||||
result = schemas_out.AssetDetail(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash if asset else None,
|
||||
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
|
||||
mime_type=asset.mime_type if asset else None,
|
||||
tags=tags,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
asset_id = info_row.asset_id if info_row else None
|
||||
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not delete_content_if_orphan or not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
|
||||
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
return True
|
||||
|
||||
|
||||
def create_asset_from_hash(
|
||||
*,
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.AssetCreated | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
||||
if not asset:
|
||||
return None
|
||||
|
||||
info = create_asset_info_for_existing_asset(
|
||||
session,
|
||||
asset_hash=canonical,
|
||||
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
)
|
||||
tag_names = get_asset_tags(session, asset_info_id=info.id)
|
||||
result = schemas_out.AssetCreated(
|
||||
id=info.id,
|
||||
name=info.name,
|
||||
asset_hash=asset.hash,
|
||||
size=int(asset.size_bytes),
|
||||
mime_type=asset.mime_type,
|
||||
tags=tag_names,
|
||||
user_metadata=info.user_metadata or {},
|
||||
preview_id=info.preview_id,
|
||||
created_at=info.created_at,
|
||||
last_access_time=info.last_access_time,
|
||||
created_new=False,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def add_tags_to_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsAdd:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
data = add_tags_to_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
asset_info_row=info_row,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsAdd(**data)
|
||||
|
||||
|
||||
def remove_tags_from_asset(
|
||||
*,
|
||||
asset_info_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsRemove:
|
||||
with create_session() as session:
|
||||
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
|
||||
if not info_row:
|
||||
raise ValueError(f"AssetInfo {asset_info_id} not found")
|
||||
if info_row.owner_id and info_row.owner_id != owner_id:
|
||||
raise PermissionError("not owner")
|
||||
|
||||
data = remove_tags_from_asset_info(
|
||||
session,
|
||||
asset_info_id=asset_info_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
return schemas_out.TagsRemove(**data)
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> schemas_out.TagsList:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
|
||||
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)
|
||||
@ -1,263 +1,567 @@
|
||||
import contextlib
|
||||
import time
|
||||
import logging
|
||||
import os
|
||||
import sqlalchemy
|
||||
from pathlib import Path
|
||||
from typing import Callable, Literal, TypedDict
|
||||
|
||||
import folder_paths
|
||||
from app.database.db import create_session, dependencies_available
|
||||
from app.assets.helpers import (
|
||||
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
|
||||
list_tree,prefixes_for_root, escape_like_prefix,
|
||||
RootType
|
||||
from app.assets.database.queries import (
|
||||
add_missing_tag_for_asset_id,
|
||||
bulk_update_enrichment_level,
|
||||
bulk_update_is_missing,
|
||||
bulk_update_needs_verify,
|
||||
delete_orphaned_seed_asset,
|
||||
delete_references_by_ids,
|
||||
ensure_tags_exist,
|
||||
get_asset_by_hash,
|
||||
get_references_for_prefixes,
|
||||
get_unenriched_references,
|
||||
mark_references_missing_outside_prefixes,
|
||||
reassign_asset_references,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_system_metadata,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
|
||||
from app.assets.database.bulk_ops import seed_from_paths_batch
|
||||
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
|
||||
from app.assets.services.bulk_ingest import (
|
||||
SeedAssetSpec,
|
||||
batch_insert_seed_assets,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
is_visible,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_comfy_models_folders,
|
||||
get_name_and_tags_from_asset_path,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
|
||||
"""
|
||||
Scan the given roots and seed the assets into the database.
|
||||
"""
|
||||
if not dependencies_available():
|
||||
if enable_logging:
|
||||
logging.warning("Database dependencies not available, skipping assets scan")
|
||||
return
|
||||
t_start = time.perf_counter()
|
||||
created = 0
|
||||
skipped_existing = 0
|
||||
orphans_pruned = 0
|
||||
paths: list[str] = []
|
||||
try:
|
||||
existing_paths: set[str] = set()
|
||||
for r in roots:
|
||||
try:
|
||||
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
|
||||
if survivors:
|
||||
existing_paths.update(survivors)
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", r, e)
|
||||
class _RefInfo(TypedDict):
|
||||
ref_id: str
|
||||
file_path: str
|
||||
exists: bool
|
||||
stat_unchanged: bool
|
||||
needs_verify: bool
|
||||
|
||||
try:
|
||||
orphans_pruned = _prune_orphaned_assets(roots)
|
||||
except Exception as e:
|
||||
logging.exception("orphan pruning failed: %s", e)
|
||||
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_tree(folder_paths.get_output_directory()))
|
||||
class _AssetAccumulator(TypedDict):
|
||||
hash: str | None
|
||||
size_db: int
|
||||
refs: list[_RefInfo]
|
||||
|
||||
specs: list[dict] = []
|
||||
tag_pool: set[str] = set()
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped_existing += 1
|
||||
|
||||
RootType = Literal["models", "input", "output"]
|
||||
|
||||
|
||||
def get_prefixes_for_root(root: RootType) -> list[str]:
|
||||
if root == "models":
|
||||
bases: list[str] = []
|
||||
for _bucket, paths in get_comfy_models_folders():
|
||||
bases.extend(paths)
|
||||
return [os.path.abspath(p) for p in bases]
|
||||
if root == "input":
|
||||
return [os.path.abspath(folder_paths.get_input_directory())]
|
||||
if root == "output":
|
||||
return [os.path.abspath(folder_paths.get_output_directory())]
|
||||
return []
|
||||
|
||||
|
||||
def get_all_known_prefixes() -> list[str]:
|
||||
"""Get all known asset prefixes across all root types."""
|
||||
all_roots: tuple[RootType, ...] = ("models", "input", "output")
|
||||
return [p for root in all_roots for p in get_prefixes_for_root(root)]
|
||||
|
||||
|
||||
def collect_models_files() -> list[str]:
|
||||
out: list[str] = []
|
||||
for folder_name, bases in get_comfy_models_folders():
|
||||
rel_files = folder_paths.get_filename_list(folder_name) or []
|
||||
for rel_path in rel_files:
|
||||
if not all(is_visible(part) for part in Path(rel_path).parts):
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=False)
|
||||
except OSError:
|
||||
abs_path = folder_paths.get_full_path(folder_name, rel_path)
|
||||
if not abs_path:
|
||||
continue
|
||||
# skip empty files
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": compute_relative_filename(abs_p),
|
||||
}
|
||||
)
|
||||
for t in tags:
|
||||
tag_pool.add(t)
|
||||
# if no file specs, nothing to do
|
||||
if not specs:
|
||||
return
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
|
||||
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
|
||||
created += result["inserted_infos"]
|
||||
sess.commit()
|
||||
finally:
|
||||
if enable_logging:
|
||||
logging.info(
|
||||
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
|
||||
roots,
|
||||
time.perf_counter() - t_start,
|
||||
created,
|
||||
skipped_existing,
|
||||
orphans_pruned,
|
||||
len(paths),
|
||||
)
|
||||
abs_path = os.path.abspath(abs_path)
|
||||
allowed = False
|
||||
abs_p = Path(abs_path)
|
||||
for b in bases:
|
||||
if abs_p.is_relative_to(os.path.abspath(b)):
|
||||
allowed = True
|
||||
break
|
||||
if allowed:
|
||||
out.append(abs_path)
|
||||
return out
|
||||
|
||||
|
||||
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
|
||||
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
|
||||
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
|
||||
if not all_prefixes:
|
||||
return 0
|
||||
|
||||
def make_prefix_condition(prefix: str):
|
||||
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
|
||||
|
||||
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
|
||||
|
||||
orphan_subq = (
|
||||
sqlalchemy.select(Asset.id)
|
||||
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
|
||||
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
|
||||
).scalar_subquery()
|
||||
|
||||
with create_session() as sess:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
|
||||
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
|
||||
sess.commit()
|
||||
return result.rowcount
|
||||
|
||||
|
||||
def _fast_db_consistency_pass(
|
||||
def sync_references_with_filesystem(
|
||||
session,
|
||||
root: RootType,
|
||||
*,
|
||||
collect_existing_paths: bool = False,
|
||||
update_missing_tags: bool = False,
|
||||
) -> set[str] | None:
|
||||
"""Fast DB+FS pass for a root:
|
||||
- Toggle needs_verify per state using fast check
|
||||
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
|
||||
- For seed assets with all states missing: delete Asset and its AssetInfos
|
||||
- Optionally add/remove 'missing' tags based on fast-ok in this root
|
||||
- Optionally return surviving absolute paths
|
||||
"""Reconcile asset references with filesystem for a root.
|
||||
|
||||
- Toggle needs_verify per reference using mtime/size stat check
|
||||
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
|
||||
- For seed assets with all refs missing: delete Asset and its references
|
||||
- Optionally add/remove 'missing' tags based on stat check in this root
|
||||
- Optionally return surviving absolute paths
|
||||
|
||||
Args:
|
||||
session: Database session
|
||||
root: Root type to scan
|
||||
collect_existing_paths: If True, return set of surviving file paths
|
||||
update_missing_tags: If True, update 'missing' tags based on file status
|
||||
|
||||
Returns:
|
||||
Set of surviving absolute paths if collect_existing_paths=True, else None
|
||||
"""
|
||||
prefixes = prefixes_for_root(root)
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if not prefixes:
|
||||
return set() if collect_existing_paths else None
|
||||
|
||||
conds = []
|
||||
for p in prefixes:
|
||||
base = os.path.abspath(p)
|
||||
if not base.endswith(os.sep):
|
||||
base += os.sep
|
||||
escaped, esc = escape_like_prefix(base)
|
||||
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
|
||||
rows = get_references_for_prefixes(
|
||||
session, prefixes, include_missing=update_missing_tags
|
||||
)
|
||||
|
||||
by_asset: dict[str, _AssetAccumulator] = {}
|
||||
for row in rows:
|
||||
acc = by_asset.get(row.asset_id)
|
||||
if acc is None:
|
||||
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
|
||||
by_asset[row.asset_id] = acc
|
||||
|
||||
stat_unchanged = False
|
||||
try:
|
||||
exists = True
|
||||
stat_unchanged = verify_file_unchanged(
|
||||
mtime_db=row.mtime_ns,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(row.file_path, follow_symlinks=True),
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except PermissionError:
|
||||
exists = True
|
||||
logging.debug("Permission denied accessing %s", row.file_path)
|
||||
except OSError as e:
|
||||
exists = False
|
||||
logging.debug("OSError checking %s: %s", row.file_path, e)
|
||||
|
||||
acc["refs"].append(
|
||||
{
|
||||
"ref_id": row.reference_id,
|
||||
"file_path": row.file_path,
|
||||
"exists": exists,
|
||||
"stat_unchanged": stat_unchanged,
|
||||
"needs_verify": row.needs_verify,
|
||||
}
|
||||
)
|
||||
|
||||
to_set_verify: list[str] = []
|
||||
to_clear_verify: list[str] = []
|
||||
stale_ref_ids: list[str] = []
|
||||
to_mark_missing: list[str] = []
|
||||
to_clear_missing: list[str] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
refs = acc["refs"]
|
||||
any_unchanged = any(r["stat_unchanged"] for r in refs)
|
||||
all_missing = all(not r["exists"] for r in refs)
|
||||
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
to_mark_missing.append(r["ref_id"])
|
||||
continue
|
||||
if r["stat_unchanged"]:
|
||||
to_clear_missing.append(r["ref_id"])
|
||||
if r["needs_verify"]:
|
||||
to_clear_verify.append(r["ref_id"])
|
||||
if not r["stat_unchanged"] and not r["needs_verify"]:
|
||||
to_set_verify.append(r["ref_id"])
|
||||
|
||||
if a_hash is None:
|
||||
if refs and all_missing:
|
||||
delete_orphaned_seed_asset(session, aid)
|
||||
else:
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
continue
|
||||
|
||||
if any_unchanged:
|
||||
for r in refs:
|
||||
if not r["exists"]:
|
||||
stale_ref_ids.append(r["ref_id"])
|
||||
if update_missing_tags:
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=aid)
|
||||
except Exception as e:
|
||||
logging.warning(
|
||||
"Failed to remove missing tag for asset %s: %s", aid, e
|
||||
)
|
||||
elif update_missing_tags:
|
||||
try:
|
||||
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
|
||||
except Exception as e:
|
||||
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
|
||||
|
||||
for r in refs:
|
||||
if r["exists"]:
|
||||
survivors.add(os.path.abspath(r["file_path"]))
|
||||
|
||||
delete_references_by_ids(session, stale_ref_ids)
|
||||
stale_set = set(stale_ref_ids)
|
||||
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
|
||||
bulk_update_is_missing(session, to_mark_missing, value=True)
|
||||
bulk_update_is_missing(session, to_clear_missing, value=False)
|
||||
bulk_update_needs_verify(session, to_set_verify, value=True)
|
||||
bulk_update_needs_verify(session, to_clear_verify, value=False)
|
||||
|
||||
return survivors if collect_existing_paths else None
|
||||
|
||||
|
||||
def sync_root_safely(root: RootType) -> set[str]:
|
||||
"""Sync a single root's references with the filesystem.
|
||||
|
||||
Returns survivors (existing paths) or empty set on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
survivors = sync_references_with_filesystem(
|
||||
sess,
|
||||
root,
|
||||
collect_existing_paths=True,
|
||||
update_missing_tags=True,
|
||||
)
|
||||
sess.commit()
|
||||
return survivors or set()
|
||||
except Exception as e:
|
||||
logging.exception("fast DB scan failed for %s: %s", root, e)
|
||||
return set()
|
||||
|
||||
|
||||
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
|
||||
"""Mark references as missing when outside the given prefixes.
|
||||
|
||||
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
|
||||
"""
|
||||
try:
|
||||
with create_session() as sess:
|
||||
count = mark_references_missing_outside_prefixes(sess, prefixes)
|
||||
sess.commit()
|
||||
return count
|
||||
except Exception as e:
|
||||
logging.exception("marking missing assets failed: %s", e)
|
||||
return 0
|
||||
|
||||
|
||||
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
|
||||
"""Collect all file paths for the given roots."""
|
||||
paths: list[str] = []
|
||||
if "models" in roots:
|
||||
paths.extend(collect_models_files())
|
||||
if "input" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
|
||||
if "output" in roots:
|
||||
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
|
||||
return paths
|
||||
|
||||
|
||||
def build_asset_specs(
|
||||
paths: list[str],
|
||||
existing_paths: set[str],
|
||||
enable_metadata_extraction: bool = True,
|
||||
compute_hashes: bool = False,
|
||||
) -> tuple[list[SeedAssetSpec], set[str], int]:
|
||||
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
|
||||
|
||||
Args:
|
||||
paths: List of file paths to process
|
||||
existing_paths: Set of paths that already exist in the database
|
||||
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
|
||||
compute_hashes: If True, compute blake3 hashes (slow for large files)
|
||||
"""
|
||||
specs: list[SeedAssetSpec] = []
|
||||
tag_pool: set[str] = set()
|
||||
skipped = 0
|
||||
|
||||
for p in paths:
|
||||
abs_p = os.path.abspath(p)
|
||||
if abs_p in existing_paths:
|
||||
skipped += 1
|
||||
continue
|
||||
try:
|
||||
stat_p = os.stat(abs_p, follow_symlinks=True)
|
||||
except OSError:
|
||||
continue
|
||||
if not stat_p.st_size:
|
||||
continue
|
||||
name, tags = get_name_and_tags_from_asset_path(abs_p)
|
||||
rel_fname = compute_relative_filename(abs_p)
|
||||
|
||||
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
|
||||
metadata = None
|
||||
if enable_metadata_extraction:
|
||||
metadata = extract_file_metadata(
|
||||
abs_p,
|
||||
stat_result=stat_p,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
|
||||
# Compute hash if requested
|
||||
asset_hash: str | None = None
|
||||
if compute_hashes:
|
||||
try:
|
||||
digest, _ = compute_blake3_hash(abs_p)
|
||||
asset_hash = "blake3:" + digest
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", abs_p, e)
|
||||
|
||||
mime_type = metadata.content_type if metadata else None
|
||||
specs.append(
|
||||
{
|
||||
"abs_path": abs_p,
|
||||
"size_bytes": stat_p.st_size,
|
||||
"mtime_ns": get_mtime_ns(stat_p),
|
||||
"info_name": name,
|
||||
"tags": tags,
|
||||
"fname": rel_fname,
|
||||
"metadata": metadata,
|
||||
"hash": asset_hash,
|
||||
"mime_type": mime_type,
|
||||
}
|
||||
)
|
||||
tag_pool.update(tags)
|
||||
|
||||
return specs, tag_pool, skipped
|
||||
|
||||
|
||||
|
||||
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
"""Insert asset specs into database, returning count of created refs."""
|
||||
if not specs:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_refs
|
||||
|
||||
|
||||
# Enrichment level constants
|
||||
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
|
||||
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
|
||||
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
|
||||
|
||||
|
||||
def get_unenriched_assets_for_roots(
|
||||
roots: tuple[RootType, ...],
|
||||
max_level: int = ENRICHMENT_STUB,
|
||||
limit: int = 1000,
|
||||
) -> list:
|
||||
"""Get assets that need enrichment for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
max_level: Maximum enrichment level to include
|
||||
limit: Maximum number of rows to return
|
||||
|
||||
Returns:
|
||||
List of UnenrichedReferenceRow
|
||||
"""
|
||||
prefixes: list[str] = []
|
||||
for root in roots:
|
||||
prefixes.extend(get_prefixes_for_root(root))
|
||||
|
||||
if not prefixes:
|
||||
return []
|
||||
|
||||
with create_session() as sess:
|
||||
rows = (
|
||||
sess.execute(
|
||||
sqlalchemy.select(
|
||||
AssetCacheState.id,
|
||||
AssetCacheState.file_path,
|
||||
AssetCacheState.mtime_ns,
|
||||
AssetCacheState.needs_verify,
|
||||
AssetCacheState.asset_id,
|
||||
Asset.hash,
|
||||
Asset.size_bytes,
|
||||
)
|
||||
.join(Asset, Asset.id == AssetCacheState.asset_id)
|
||||
.where(sqlalchemy.or_(*conds))
|
||||
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
|
||||
return get_unenriched_references(
|
||||
sess, prefixes, max_level=max_level, limit=limit
|
||||
)
|
||||
|
||||
|
||||
def enrich_asset(
|
||||
session,
|
||||
file_path: str,
|
||||
reference_id: str,
|
||||
asset_id: str,
|
||||
extract_metadata: bool = True,
|
||||
compute_hash: bool = False,
|
||||
interrupt_check: Callable[[], bool] | None = None,
|
||||
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
|
||||
) -> int:
|
||||
"""Enrich a single asset with metadata and/or hash.
|
||||
|
||||
Args:
|
||||
session: Database session (caller manages lifecycle)
|
||||
file_path: Absolute path to the file
|
||||
reference_id: ID of the reference to update
|
||||
asset_id: ID of the asset to update (for mime_type and hash)
|
||||
extract_metadata: If True, extract safetensors header and mime type
|
||||
compute_hash: If True, compute blake3 hash
|
||||
interrupt_check: Optional non-blocking callable that returns True if
|
||||
the operation should be interrupted (e.g. paused or cancelled)
|
||||
hash_checkpoints: Optional dict for saving/restoring hash progress
|
||||
across interruptions, keyed by file path
|
||||
|
||||
Returns:
|
||||
New enrichment level achieved
|
||||
"""
|
||||
new_level = ENRICHMENT_STUB
|
||||
|
||||
try:
|
||||
stat_p = os.stat(file_path, follow_symlinks=True)
|
||||
except OSError:
|
||||
return new_level
|
||||
|
||||
rel_fname = compute_relative_filename(file_path)
|
||||
mime_type: str | None = None
|
||||
metadata = None
|
||||
|
||||
if extract_metadata:
|
||||
metadata = extract_file_metadata(
|
||||
file_path,
|
||||
stat_result=stat_p,
|
||||
relative_filename=rel_fname,
|
||||
)
|
||||
if metadata:
|
||||
mime_type = metadata.content_type
|
||||
new_level = ENRICHMENT_METADATA
|
||||
|
||||
full_hash: str | None = None
|
||||
if compute_hash:
|
||||
try:
|
||||
mtime_before = get_mtime_ns(stat_p)
|
||||
size_before = stat_p.st_size
|
||||
|
||||
# Restore checkpoint if available and file unchanged
|
||||
checkpoint = None
|
||||
if hash_checkpoints is not None:
|
||||
checkpoint = hash_checkpoints.get(file_path)
|
||||
if checkpoint is not None:
|
||||
cur_stat = os.stat(file_path, follow_symlinks=True)
|
||||
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
|
||||
or checkpoint.file_size != cur_stat.st_size):
|
||||
checkpoint = None
|
||||
hash_checkpoints.pop(file_path, None)
|
||||
else:
|
||||
mtime_before = get_mtime_ns(cur_stat)
|
||||
|
||||
digest, new_checkpoint = compute_blake3_hash(
|
||||
file_path,
|
||||
interrupt_check=interrupt_check,
|
||||
checkpoint=checkpoint,
|
||||
)
|
||||
).all()
|
||||
|
||||
by_asset: dict[str, dict] = {}
|
||||
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
|
||||
acc = by_asset.get(aid)
|
||||
if acc is None:
|
||||
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
|
||||
by_asset[aid] = acc
|
||||
if digest is None:
|
||||
# Interrupted — save checkpoint for later resumption
|
||||
if hash_checkpoints is not None and new_checkpoint is not None:
|
||||
new_checkpoint.mtime_ns = mtime_before
|
||||
new_checkpoint.file_size = size_before
|
||||
hash_checkpoints[file_path] = new_checkpoint
|
||||
return new_level
|
||||
|
||||
# Completed — clear any saved checkpoint
|
||||
if hash_checkpoints is not None:
|
||||
hash_checkpoints.pop(file_path, None)
|
||||
|
||||
stat_after = os.stat(file_path, follow_symlinks=True)
|
||||
mtime_after = get_mtime_ns(stat_after)
|
||||
if mtime_before != mtime_after:
|
||||
logging.warning("File modified during hashing, discarding hash: %s", file_path)
|
||||
else:
|
||||
full_hash = f"blake3:{digest}"
|
||||
metadata_ok = not extract_metadata or metadata is not None
|
||||
if metadata_ok:
|
||||
new_level = ENRICHMENT_HASHED
|
||||
except Exception as e:
|
||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
if full_hash:
|
||||
existing = get_asset_by_hash(session, full_hash)
|
||||
if existing and existing.id != asset_id:
|
||||
reassign_asset_references(session, asset_id, existing.id, reference_id)
|
||||
delete_orphaned_seed_asset(session, asset_id)
|
||||
if mime_type:
|
||||
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
|
||||
else:
|
||||
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
|
||||
elif mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
|
||||
|
||||
bulk_update_enrichment_level(session, [reference_id], new_level)
|
||||
session.commit()
|
||||
|
||||
return new_level
|
||||
|
||||
|
||||
def enrich_assets_batch(
|
||||
rows: list,
|
||||
extract_metadata: bool = True,
|
||||
compute_hash: bool = False,
|
||||
interrupt_check: Callable[[], bool] | None = None,
|
||||
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
|
||||
) -> tuple[int, list[str]]:
|
||||
"""Enrich a batch of assets.
|
||||
|
||||
Uses a single DB session for the entire batch, committing after each
|
||||
individual asset to avoid long-held transactions while eliminating
|
||||
per-asset session creation overhead.
|
||||
|
||||
Args:
|
||||
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
|
||||
extract_metadata: If True, extract metadata for each asset
|
||||
compute_hash: If True, compute hash for each asset
|
||||
interrupt_check: Optional non-blocking callable that returns True if
|
||||
the operation should be interrupted (e.g. paused or cancelled)
|
||||
hash_checkpoints: Optional dict for saving/restoring hash progress
|
||||
across interruptions, keyed by file path
|
||||
|
||||
Returns:
|
||||
Tuple of (enriched_count, failed_reference_ids)
|
||||
"""
|
||||
enriched = 0
|
||||
failed_ids: list[str] = []
|
||||
|
||||
with create_session() as sess:
|
||||
for row in rows:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
break
|
||||
|
||||
fast_ok = False
|
||||
try:
|
||||
exists = True
|
||||
fast_ok = fast_asset_file_check(
|
||||
mtime_db=mtime_db,
|
||||
size_db=acc["size_db"],
|
||||
stat_result=os.stat(fp, follow_symlinks=True),
|
||||
new_level = enrich_asset(
|
||||
sess,
|
||||
file_path=row.file_path,
|
||||
reference_id=row.reference_id,
|
||||
asset_id=row.asset_id,
|
||||
extract_metadata=extract_metadata,
|
||||
compute_hash=compute_hash,
|
||||
interrupt_check=interrupt_check,
|
||||
hash_checkpoints=hash_checkpoints,
|
||||
)
|
||||
except FileNotFoundError:
|
||||
exists = False
|
||||
except OSError:
|
||||
exists = False
|
||||
|
||||
acc["states"].append({
|
||||
"sid": sid,
|
||||
"fp": fp,
|
||||
"exists": exists,
|
||||
"fast_ok": fast_ok,
|
||||
"needs_verify": bool(needs_verify),
|
||||
})
|
||||
|
||||
to_set_verify: list[int] = []
|
||||
to_clear_verify: list[int] = []
|
||||
stale_state_ids: list[int] = []
|
||||
survivors: set[str] = set()
|
||||
|
||||
for aid, acc in by_asset.items():
|
||||
a_hash = acc["hash"]
|
||||
states = acc["states"]
|
||||
any_fast_ok = any(s["fast_ok"] for s in states)
|
||||
all_missing = all(not s["exists"] for s in states)
|
||||
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
continue
|
||||
if s["fast_ok"] and s["needs_verify"]:
|
||||
to_clear_verify.append(s["sid"])
|
||||
if not s["fast_ok"] and not s["needs_verify"]:
|
||||
to_set_verify.append(s["sid"])
|
||||
|
||||
if a_hash is None:
|
||||
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
|
||||
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
|
||||
asset = sess.get(Asset, aid)
|
||||
if asset:
|
||||
sess.delete(asset)
|
||||
if new_level > row.enrichment_level:
|
||||
enriched += 1
|
||||
else:
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
continue
|
||||
failed_ids.append(row.reference_id)
|
||||
except Exception as e:
|
||||
logging.warning("Failed to enrich %s: %s", row.file_path, e)
|
||||
sess.rollback()
|
||||
failed_ids.append(row.reference_id)
|
||||
|
||||
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
|
||||
for s in states:
|
||||
if not s["exists"]:
|
||||
stale_state_ids.append(s["sid"])
|
||||
if update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
remove_missing_tag_for_asset_id(sess, asset_id=aid)
|
||||
elif update_missing_tags:
|
||||
with contextlib.suppress(Exception):
|
||||
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
|
||||
|
||||
for s in states:
|
||||
if s["exists"]:
|
||||
survivors.add(os.path.abspath(s["fp"]))
|
||||
|
||||
if stale_state_ids:
|
||||
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
|
||||
if to_set_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_set_verify))
|
||||
.values(needs_verify=True)
|
||||
)
|
||||
if to_clear_verify:
|
||||
sess.execute(
|
||||
sqlalchemy.update(AssetCacheState)
|
||||
.where(AssetCacheState.id.in_(to_clear_verify))
|
||||
.values(needs_verify=False)
|
||||
)
|
||||
sess.commit()
|
||||
return survivors if collect_existing_paths else None
|
||||
return enriched, failed_ids
|
||||
|
||||
794
app/assets/seeder.py
Normal file
794
app/assets/seeder.py
Normal file
@ -0,0 +1,794 @@
|
||||
"""Background asset seeder with thread management and cancellation support."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from typing import Callable
|
||||
|
||||
from app.assets.scanner import (
|
||||
ENRICHMENT_METADATA,
|
||||
ENRICHMENT_STUB,
|
||||
RootType,
|
||||
build_asset_specs,
|
||||
collect_paths_for_roots,
|
||||
enrich_assets_batch,
|
||||
get_all_known_prefixes,
|
||||
get_prefixes_for_root,
|
||||
get_unenriched_assets_for_roots,
|
||||
insert_asset_specs,
|
||||
mark_missing_outside_prefixes_safely,
|
||||
sync_root_safely,
|
||||
)
|
||||
from app.database.db import dependencies_available
|
||||
|
||||
|
||||
class ScanInProgressError(Exception):
|
||||
"""Raised when an operation cannot proceed because a scan is running."""
|
||||
|
||||
|
||||
class State(Enum):
|
||||
"""Seeder state machine states."""
|
||||
|
||||
IDLE = "IDLE"
|
||||
RUNNING = "RUNNING"
|
||||
PAUSED = "PAUSED"
|
||||
CANCELLING = "CANCELLING"
|
||||
|
||||
|
||||
class ScanPhase(Enum):
|
||||
"""Scan phase options."""
|
||||
|
||||
FAST = "fast" # Phase 1: filesystem only (stubs)
|
||||
ENRICH = "enrich" # Phase 2: metadata + hash
|
||||
FULL = "full" # Both phases sequentially
|
||||
|
||||
|
||||
@dataclass
|
||||
class Progress:
|
||||
"""Progress information for a scan operation."""
|
||||
|
||||
scanned: int = 0
|
||||
total: int = 0
|
||||
created: int = 0
|
||||
skipped: int = 0
|
||||
|
||||
|
||||
@dataclass
|
||||
class ScanStatus:
|
||||
"""Current status of the asset seeder."""
|
||||
|
||||
state: State
|
||||
progress: Progress | None
|
||||
errors: list[str] = field(default_factory=list)
|
||||
|
||||
|
||||
ProgressCallback = Callable[[Progress], None]
|
||||
|
||||
|
||||
class _AssetSeeder:
|
||||
"""Background asset scanning manager.
|
||||
|
||||
Spawns ephemeral daemon threads for scanning.
|
||||
Each scan creates a new thread that exits when complete.
|
||||
Use the module-level ``asset_seeder`` instance.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._lock = threading.Lock()
|
||||
self._state = State.IDLE
|
||||
self._progress: Progress | None = None
|
||||
self._last_progress: Progress | None = None
|
||||
self._errors: list[str] = []
|
||||
self._thread: threading.Thread | None = None
|
||||
self._cancel_event = threading.Event()
|
||||
self._run_gate = threading.Event()
|
||||
self._run_gate.set() # Start unpaused (set = running, clear = paused)
|
||||
self._roots: tuple[RootType, ...] = ()
|
||||
self._phase: ScanPhase = ScanPhase.FULL
|
||||
self._compute_hashes: bool = False
|
||||
self._prune_first: bool = False
|
||||
self._progress_callback: ProgressCallback | None = None
|
||||
self._disabled: bool = False
|
||||
|
||||
def disable(self) -> None:
|
||||
"""Disable the asset seeder, preventing any scans from starting."""
|
||||
self._disabled = True
|
||||
logging.info("Asset seeder disabled")
|
||||
|
||||
def is_disabled(self) -> bool:
|
||||
"""Check if the asset seeder is disabled."""
|
||||
return self._disabled
|
||||
|
||||
def start(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
phase: ScanPhase = ScanPhase.FULL,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start a background scan for the given roots.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan (models, input, output)
|
||||
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
|
||||
progress_callback: Optional callback called with progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
compute_hashes: If True, compute blake3 hashes (slow)
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
if self._disabled:
|
||||
logging.debug("Asset seeder is disabled, skipping start")
|
||||
return False
|
||||
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
logging.info("Asset seeder already running, skipping start")
|
||||
return False
|
||||
self._state = State.RUNNING
|
||||
self._progress = Progress()
|
||||
self._errors = []
|
||||
self._roots = roots
|
||||
self._phase = phase
|
||||
self._prune_first = prune_first
|
||||
self._compute_hashes = compute_hashes
|
||||
self._progress_callback = progress_callback
|
||||
self._cancel_event.clear()
|
||||
self._run_gate.set() # Ensure unpaused when starting
|
||||
self._thread = threading.Thread(
|
||||
target=self._run_scan,
|
||||
name="_AssetSeeder",
|
||||
daemon=True,
|
||||
)
|
||||
self._thread.start()
|
||||
return True
|
||||
|
||||
def start_fast(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool = False,
|
||||
) -> bool:
|
||||
"""Start a fast scan (phase 1 only) - creates stub records.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
progress_callback: Optional callback for progress updates
|
||||
prune_first: If True, prune orphaned assets before scanning
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
return self.start(
|
||||
roots=roots,
|
||||
phase=ScanPhase.FAST,
|
||||
progress_callback=progress_callback,
|
||||
prune_first=prune_first,
|
||||
compute_hashes=False,
|
||||
)
|
||||
|
||||
def start_enrich(
|
||||
self,
|
||||
roots: tuple[RootType, ...] = ("models", "input", "output"),
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
compute_hashes: bool = False,
|
||||
) -> bool:
|
||||
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
|
||||
|
||||
Args:
|
||||
roots: Tuple of root types to scan
|
||||
progress_callback: Optional callback for progress updates
|
||||
compute_hashes: If True, compute blake3 hashes
|
||||
|
||||
Returns:
|
||||
True if scan was started, False if already running
|
||||
"""
|
||||
return self.start(
|
||||
roots=roots,
|
||||
phase=ScanPhase.ENRICH,
|
||||
progress_callback=progress_callback,
|
||||
prune_first=False,
|
||||
compute_hashes=compute_hashes,
|
||||
)
|
||||
|
||||
def cancel(self) -> bool:
|
||||
"""Request cancellation of the current scan.
|
||||
|
||||
Returns:
|
||||
True if cancellation was requested, False if not running or paused
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state not in (State.RUNNING, State.PAUSED):
|
||||
return False
|
||||
logging.info("Asset seeder cancelling (was %s)", self._state.value)
|
||||
self._state = State.CANCELLING
|
||||
self._cancel_event.set()
|
||||
self._run_gate.set() # Unblock if paused so thread can exit
|
||||
return True
|
||||
|
||||
def stop(self) -> bool:
|
||||
"""Stop the current scan (alias for cancel).
|
||||
|
||||
Returns:
|
||||
True if stop was requested, False if not running
|
||||
"""
|
||||
return self.cancel()
|
||||
|
||||
def pause(self) -> bool:
|
||||
"""Pause the current scan.
|
||||
|
||||
The scan will complete its current batch before pausing.
|
||||
|
||||
Returns:
|
||||
True if pause was requested, False if not running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.RUNNING:
|
||||
return False
|
||||
logging.info("Asset seeder pausing")
|
||||
self._state = State.PAUSED
|
||||
self._run_gate.clear()
|
||||
return True
|
||||
|
||||
def resume(self) -> bool:
|
||||
"""Resume a paused scan.
|
||||
|
||||
This is a noop if the scan is not in the PAUSED state
|
||||
|
||||
Returns:
|
||||
True if resumed, False if not paused
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.PAUSED:
|
||||
return False
|
||||
logging.info("Asset seeder resuming")
|
||||
self._state = State.RUNNING
|
||||
self._run_gate.set()
|
||||
self._emit_event("assets.seed.resumed", {})
|
||||
return True
|
||||
|
||||
def restart(
|
||||
self,
|
||||
roots: tuple[RootType, ...] | None = None,
|
||||
phase: ScanPhase | None = None,
|
||||
progress_callback: ProgressCallback | None = None,
|
||||
prune_first: bool | None = None,
|
||||
compute_hashes: bool | None = None,
|
||||
timeout: float = 5.0,
|
||||
) -> bool:
|
||||
"""Cancel any running scan and start a new one.
|
||||
|
||||
Args:
|
||||
roots: Roots to scan (defaults to previous roots)
|
||||
phase: Scan phase (defaults to previous phase)
|
||||
progress_callback: Progress callback (defaults to previous)
|
||||
prune_first: Prune before scan (defaults to previous)
|
||||
compute_hashes: Compute hashes (defaults to previous)
|
||||
timeout: Max seconds to wait for current scan to stop
|
||||
|
||||
Returns:
|
||||
True if new scan was started, False if failed to stop previous
|
||||
"""
|
||||
logging.info("Asset seeder restart requested")
|
||||
with self._lock:
|
||||
prev_roots = self._roots
|
||||
prev_phase = self._phase
|
||||
prev_callback = self._progress_callback
|
||||
prev_prune = self._prune_first
|
||||
prev_hashes = self._compute_hashes
|
||||
|
||||
self.cancel()
|
||||
if not self.wait(timeout=timeout):
|
||||
return False
|
||||
|
||||
cb = progress_callback if progress_callback is not None else prev_callback
|
||||
return self.start(
|
||||
roots=roots if roots is not None else prev_roots,
|
||||
phase=phase if phase is not None else prev_phase,
|
||||
progress_callback=cb,
|
||||
prune_first=prune_first if prune_first is not None else prev_prune,
|
||||
compute_hashes=(
|
||||
compute_hashes if compute_hashes is not None else prev_hashes
|
||||
),
|
||||
)
|
||||
|
||||
def wait(self, timeout: float | None = None) -> bool:
|
||||
"""Wait for the current scan to complete.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait, or None for no timeout
|
||||
|
||||
Returns:
|
||||
True if scan completed, False if timeout expired or no scan running
|
||||
"""
|
||||
with self._lock:
|
||||
thread = self._thread
|
||||
if thread is None:
|
||||
return True
|
||||
thread.join(timeout=timeout)
|
||||
return not thread.is_alive()
|
||||
|
||||
def get_status(self) -> ScanStatus:
|
||||
"""Get the current status and progress of the seeder."""
|
||||
with self._lock:
|
||||
src = self._progress or self._last_progress
|
||||
return ScanStatus(
|
||||
state=self._state,
|
||||
progress=Progress(
|
||||
scanned=src.scanned,
|
||||
total=src.total,
|
||||
created=src.created,
|
||||
skipped=src.skipped,
|
||||
)
|
||||
if src
|
||||
else None,
|
||||
errors=list(self._errors),
|
||||
)
|
||||
|
||||
def shutdown(self, timeout: float = 5.0) -> None:
|
||||
"""Gracefully shutdown: cancel any running scan and wait for thread.
|
||||
|
||||
Args:
|
||||
timeout: Maximum seconds to wait for thread to exit
|
||||
"""
|
||||
self.cancel()
|
||||
self.wait(timeout=timeout)
|
||||
with self._lock:
|
||||
self._thread = None
|
||||
|
||||
def mark_missing_outside_prefixes(self) -> int:
|
||||
"""Mark references as missing when outside all known root prefixes.
|
||||
|
||||
This is a non-destructive soft-delete operation. Assets and their
|
||||
metadata are preserved, but references are flagged as missing.
|
||||
They can be restored if the file reappears in a future scan.
|
||||
|
||||
This operation is decoupled from scanning to prevent partial scans
|
||||
from accidentally marking assets belonging to other roots.
|
||||
|
||||
Should be called explicitly when cleanup is desired, typically after
|
||||
a full scan of all roots or during maintenance.
|
||||
|
||||
Returns:
|
||||
Number of references marked as missing
|
||||
|
||||
Raises:
|
||||
ScanInProgressError: If a scan is currently running
|
||||
"""
|
||||
with self._lock:
|
||||
if self._state != State.IDLE:
|
||||
raise ScanInProgressError(
|
||||
"Cannot mark missing assets while scan is running"
|
||||
)
|
||||
self._state = State.RUNNING
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
logging.warning(
|
||||
"Database dependencies not available, skipping mark missing"
|
||||
)
|
||||
return 0
|
||||
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d references as missing", marked)
|
||||
return marked
|
||||
finally:
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _is_cancelled(self) -> bool:
|
||||
"""Check if cancellation has been requested."""
|
||||
return self._cancel_event.is_set()
|
||||
|
||||
def _is_paused_or_cancelled(self) -> bool:
|
||||
"""Non-blocking check: True if paused or cancelled.
|
||||
|
||||
Use as interrupt_check for I/O-bound work (e.g. hashing) so that
|
||||
file handles are released immediately on pause rather than held
|
||||
open while blocked. The caller is responsible for blocking on
|
||||
_check_pause_and_cancel() afterward.
|
||||
"""
|
||||
return not self._run_gate.is_set() or self._cancel_event.is_set()
|
||||
|
||||
def _check_pause_and_cancel(self) -> bool:
|
||||
"""Block while paused, then check if cancelled.
|
||||
|
||||
Call this at checkpoint locations in scan loops. It will:
|
||||
1. Block indefinitely while paused (until resume or cancel)
|
||||
2. Return True if cancelled, False to continue
|
||||
|
||||
Returns:
|
||||
True if scan should stop, False to continue
|
||||
"""
|
||||
if not self._run_gate.is_set():
|
||||
self._emit_event("assets.seed.paused", {})
|
||||
self._run_gate.wait() # Blocks if paused
|
||||
return self._is_cancelled()
|
||||
|
||||
def _emit_event(self, event_type: str, data: dict) -> None:
|
||||
"""Emit a WebSocket event if server is available."""
|
||||
try:
|
||||
from server import PromptServer
|
||||
|
||||
if hasattr(PromptServer, "instance") and PromptServer.instance:
|
||||
PromptServer.instance.send_sync(event_type, data)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
def _update_progress(
|
||||
self,
|
||||
scanned: int | None = None,
|
||||
total: int | None = None,
|
||||
created: int | None = None,
|
||||
skipped: int | None = None,
|
||||
) -> None:
|
||||
"""Update progress counters (thread-safe)."""
|
||||
callback: ProgressCallback | None = None
|
||||
progress: Progress | None = None
|
||||
|
||||
with self._lock:
|
||||
if self._progress is None:
|
||||
return
|
||||
if scanned is not None:
|
||||
self._progress.scanned = scanned
|
||||
if total is not None:
|
||||
self._progress.total = total
|
||||
if created is not None:
|
||||
self._progress.created = created
|
||||
if skipped is not None:
|
||||
self._progress.skipped = skipped
|
||||
if self._progress_callback:
|
||||
callback = self._progress_callback
|
||||
progress = Progress(
|
||||
scanned=self._progress.scanned,
|
||||
total=self._progress.total,
|
||||
created=self._progress.created,
|
||||
skipped=self._progress.skipped,
|
||||
)
|
||||
|
||||
if callback and progress:
|
||||
try:
|
||||
callback(progress)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
_MAX_ERRORS = 200
|
||||
|
||||
def _add_error(self, message: str) -> None:
|
||||
"""Add an error message (thread-safe), capped at _MAX_ERRORS."""
|
||||
with self._lock:
|
||||
if len(self._errors) < self._MAX_ERRORS:
|
||||
self._errors.append(message)
|
||||
|
||||
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
|
||||
"""Log the directories that will be scanned."""
|
||||
import folder_paths
|
||||
|
||||
for root in roots:
|
||||
if root == "models":
|
||||
logging.info(
|
||||
"Asset scan [models] directory: %s",
|
||||
os.path.abspath(folder_paths.models_dir),
|
||||
)
|
||||
else:
|
||||
prefixes = get_prefixes_for_root(root)
|
||||
if prefixes:
|
||||
logging.info("Asset scan [%s] directories: %s", root, prefixes)
|
||||
|
||||
def _run_scan(self) -> None:
|
||||
"""Main scan loop running in background thread."""
|
||||
t_start = time.perf_counter()
|
||||
roots = self._roots
|
||||
phase = self._phase
|
||||
cancelled = False
|
||||
total_created = 0
|
||||
total_enriched = 0
|
||||
skipped_existing = 0
|
||||
total_paths = 0
|
||||
|
||||
try:
|
||||
if not dependencies_available():
|
||||
self._add_error("Database dependencies not available")
|
||||
self._emit_event(
|
||||
"assets.seed.error",
|
||||
{"message": "Database dependencies not available"},
|
||||
)
|
||||
return
|
||||
|
||||
if self._prune_first:
|
||||
all_prefixes = get_all_known_prefixes()
|
||||
marked = mark_missing_outside_prefixes_safely(all_prefixes)
|
||||
if marked > 0:
|
||||
logging.info("Marked %d refs as missing before scan", marked)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Asset scan cancelled after pruning phase")
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._log_scan_config(roots)
|
||||
|
||||
# Phase 1: Fast scan (stub records)
|
||||
if phase in (ScanPhase.FAST, ScanPhase.FULL):
|
||||
created, skipped, paths = self._run_fast_phase(roots)
|
||||
total_created, skipped_existing, total_paths = created, skipped, paths
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.fast_complete",
|
||||
{
|
||||
"roots": list(roots),
|
||||
"created": total_created,
|
||||
"skipped": skipped_existing,
|
||||
"total": total_paths,
|
||||
},
|
||||
)
|
||||
|
||||
# Phase 2: Enrichment scan (metadata + hashes)
|
||||
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
|
||||
if self._check_pause_and_cancel():
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
|
||||
|
||||
if enrich_cancelled:
|
||||
cancelled = True
|
||||
return
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.enrich_complete",
|
||||
{
|
||||
"roots": list(roots),
|
||||
"enriched": total_enriched,
|
||||
},
|
||||
)
|
||||
|
||||
elapsed = time.perf_counter() - t_start
|
||||
logging.info(
|
||||
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
|
||||
roots,
|
||||
phase.value,
|
||||
elapsed,
|
||||
total_created,
|
||||
total_enriched,
|
||||
skipped_existing,
|
||||
)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.completed",
|
||||
{
|
||||
"phase": phase.value,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
"enriched": total_enriched,
|
||||
"skipped": skipped_existing,
|
||||
"elapsed": round(elapsed, 3),
|
||||
},
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
self._add_error(f"Scan failed: {e}")
|
||||
logging.exception("Asset scan failed")
|
||||
self._emit_event("assets.seed.error", {"message": str(e)})
|
||||
finally:
|
||||
if cancelled:
|
||||
self._emit_event(
|
||||
"assets.seed.cancelled",
|
||||
{
|
||||
"scanned": self._progress.scanned if self._progress else 0,
|
||||
"total": total_paths,
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
with self._lock:
|
||||
self._last_progress = self._progress
|
||||
self._state = State.IDLE
|
||||
self._progress = None
|
||||
|
||||
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
|
||||
"""Run phase 1: fast scan to create stub records.
|
||||
|
||||
Returns:
|
||||
Tuple of (total_created, skipped_existing, total_paths)
|
||||
"""
|
||||
t_fast_start = time.perf_counter()
|
||||
total_created = 0
|
||||
skipped_existing = 0
|
||||
|
||||
existing_paths: set[str] = set()
|
||||
t_sync = time.perf_counter()
|
||||
for r in roots:
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, 0
|
||||
existing_paths.update(sync_root_safely(r))
|
||||
logging.debug(
|
||||
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
|
||||
time.perf_counter() - t_sync,
|
||||
len(existing_paths),
|
||||
)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, 0
|
||||
|
||||
t_collect = time.perf_counter()
|
||||
paths = collect_paths_for_roots(roots)
|
||||
logging.debug(
|
||||
"Fast scan: collect_paths took %.3fs (%d paths found)",
|
||||
time.perf_counter() - t_collect,
|
||||
len(paths),
|
||||
)
|
||||
total_paths = len(paths)
|
||||
self._update_progress(total=total_paths)
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "total": total_paths, "phase": "fast"},
|
||||
)
|
||||
|
||||
# Use stub specs (no metadata extraction, no hashing)
|
||||
t_specs = time.perf_counter()
|
||||
specs, tag_pool, skipped_existing = build_asset_specs(
|
||||
paths,
|
||||
existing_paths,
|
||||
enable_metadata_extraction=False,
|
||||
compute_hashes=False,
|
||||
)
|
||||
logging.debug(
|
||||
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
|
||||
time.perf_counter() - t_specs,
|
||||
len(specs),
|
||||
skipped_existing,
|
||||
)
|
||||
self._update_progress(skipped=skipped_existing)
|
||||
|
||||
if self._check_pause_and_cancel():
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
batch_size = 500
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
for i in range(0, len(specs), batch_size):
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info(
|
||||
"Fast scan cancelled after %d/%d files (created=%d)",
|
||||
i,
|
||||
len(specs),
|
||||
total_created,
|
||||
)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
batch = specs[i : i + batch_size]
|
||||
batch_tags = {t for spec in batch for t in spec["tags"]}
|
||||
try:
|
||||
created = insert_asset_specs(batch, batch_tags)
|
||||
total_created += created
|
||||
except Exception as e:
|
||||
self._add_error(f"Batch insert failed at offset {i}: {e}")
|
||||
logging.exception("Batch insert failed at offset %d", i)
|
||||
|
||||
scanned = i + len(batch)
|
||||
now = time.perf_counter()
|
||||
self._update_progress(scanned=scanned, created=total_created)
|
||||
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"phase": "fast",
|
||||
"scanned": scanned,
|
||||
"total": len(specs),
|
||||
"created": total_created,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
self._update_progress(scanned=len(specs), created=total_created)
|
||||
logging.info(
|
||||
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
|
||||
time.perf_counter() - t_fast_start,
|
||||
total_created,
|
||||
skipped_existing,
|
||||
total_paths,
|
||||
)
|
||||
return total_created, skipped_existing, total_paths
|
||||
|
||||
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
|
||||
"""Run phase 2: enrich existing records with metadata and hashes.
|
||||
|
||||
Returns:
|
||||
Tuple of (cancelled, total_enriched)
|
||||
"""
|
||||
total_enriched = 0
|
||||
batch_size = 100
|
||||
last_progress_time = time.perf_counter()
|
||||
progress_interval = 1.0
|
||||
|
||||
# Get the target enrichment level based on compute_hashes
|
||||
if not self._compute_hashes:
|
||||
target_max_level = ENRICHMENT_STUB
|
||||
else:
|
||||
target_max_level = ENRICHMENT_METADATA
|
||||
|
||||
self._emit_event(
|
||||
"assets.seed.started",
|
||||
{"roots": list(roots), "phase": "enrich"},
|
||||
)
|
||||
|
||||
skip_ids: set[str] = set()
|
||||
consecutive_empty = 0
|
||||
max_consecutive_empty = 3
|
||||
|
||||
# Hash checkpoints survive across batches so interrupted hashes
|
||||
# can be resumed without re-reading the entire file.
|
||||
hash_checkpoints: dict[str, object] = {}
|
||||
|
||||
while True:
|
||||
if self._check_pause_and_cancel():
|
||||
logging.info("Enrich scan cancelled after %d assets", total_enriched)
|
||||
return True, total_enriched
|
||||
|
||||
# Fetch next batch of unenriched assets
|
||||
unenriched = get_unenriched_assets_for_roots(
|
||||
roots,
|
||||
max_level=target_max_level,
|
||||
limit=batch_size,
|
||||
)
|
||||
|
||||
# Filter out previously failed references
|
||||
if skip_ids:
|
||||
unenriched = [r for r in unenriched if r.reference_id not in skip_ids]
|
||||
|
||||
if not unenriched:
|
||||
break
|
||||
|
||||
enriched, failed_ids = enrich_assets_batch(
|
||||
unenriched,
|
||||
extract_metadata=True,
|
||||
compute_hash=self._compute_hashes,
|
||||
interrupt_check=self._is_paused_or_cancelled,
|
||||
hash_checkpoints=hash_checkpoints,
|
||||
)
|
||||
total_enriched += enriched
|
||||
skip_ids.update(failed_ids)
|
||||
|
||||
if enriched == 0:
|
||||
consecutive_empty += 1
|
||||
if consecutive_empty >= max_consecutive_empty:
|
||||
logging.warning(
|
||||
"Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
|
||||
consecutive_empty,
|
||||
len(skip_ids),
|
||||
)
|
||||
break
|
||||
else:
|
||||
consecutive_empty = 0
|
||||
|
||||
now = time.perf_counter()
|
||||
if now - last_progress_time >= progress_interval:
|
||||
self._emit_event(
|
||||
"assets.seed.progress",
|
||||
{
|
||||
"phase": "enrich",
|
||||
"enriched": total_enriched,
|
||||
},
|
||||
)
|
||||
last_progress_time = now
|
||||
|
||||
return False, total_enriched
|
||||
|
||||
|
||||
asset_seeder = _AssetSeeder()
|
||||
87
app/assets/services/__init__.py
Normal file
87
app/assets/services/__init__.py
Normal file
@ -0,0 +1,87 @@
|
||||
from app.assets.services.asset_management import (
|
||||
asset_exists,
|
||||
delete_asset_reference,
|
||||
get_asset_by_hash,
|
||||
get_asset_detail,
|
||||
list_assets_page,
|
||||
resolve_asset_for_download,
|
||||
set_asset_preview,
|
||||
update_asset_metadata,
|
||||
)
|
||||
from app.assets.services.bulk_ingest import (
|
||||
BulkInsertResult,
|
||||
batch_insert_seed_assets,
|
||||
cleanup_unreferenced_assets,
|
||||
)
|
||||
from app.assets.services.file_utils import (
|
||||
get_mtime_ns,
|
||||
get_size_and_mtime_ns,
|
||||
list_files_recursively,
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.ingest import (
|
||||
DependencyMissingError,
|
||||
HashMismatchError,
|
||||
create_from_hash,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
IngestResult,
|
||||
ListAssetsResult,
|
||||
ReferenceData,
|
||||
RegisterAssetResult,
|
||||
TagUsage,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
)
|
||||
from app.assets.services.tagging import (
|
||||
apply_tags,
|
||||
list_tags,
|
||||
remove_tags,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"AddTagsResult",
|
||||
"AssetData",
|
||||
"AssetDetailResult",
|
||||
"AssetSummaryData",
|
||||
"ReferenceData",
|
||||
"BulkInsertResult",
|
||||
"DependencyMissingError",
|
||||
"DownloadResolutionResult",
|
||||
"HashMismatchError",
|
||||
"IngestResult",
|
||||
"ListAssetsResult",
|
||||
"RegisterAssetResult",
|
||||
"RemoveTagsResult",
|
||||
"TagUsage",
|
||||
"UploadResult",
|
||||
"UserMetadata",
|
||||
"apply_tags",
|
||||
"asset_exists",
|
||||
"batch_insert_seed_assets",
|
||||
"create_from_hash",
|
||||
"delete_asset_reference",
|
||||
"get_asset_by_hash",
|
||||
"get_asset_detail",
|
||||
"get_mtime_ns",
|
||||
"get_size_and_mtime_ns",
|
||||
"list_assets_page",
|
||||
"list_files_recursively",
|
||||
"list_tags",
|
||||
"cleanup_unreferenced_assets",
|
||||
"remove_tags",
|
||||
"resolve_asset_for_download",
|
||||
"set_asset_preview",
|
||||
"update_asset_metadata",
|
||||
"upload_from_temp_path",
|
||||
"verify_file_unchanged",
|
||||
]
|
||||
367
app/assets/services/asset_management.py
Normal file
367
app/assets/services/asset_management.py
Normal file
@ -0,0 +1,367 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Sequence
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
reference_exists_for_asset_id,
|
||||
delete_reference_by_id,
|
||||
fetch_reference_and_asset,
|
||||
soft_delete_reference_by_id,
|
||||
fetch_reference_asset_and_tags,
|
||||
get_asset_by_hash as queries_get_asset_by_hash,
|
||||
get_reference_by_id,
|
||||
get_reference_with_owner_check,
|
||||
list_references_page,
|
||||
list_all_file_paths_by_asset_id,
|
||||
list_references_by_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_preview,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
update_reference_access_time,
|
||||
update_reference_name,
|
||||
update_reference_updated_at,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path
|
||||
from app.assets.services.path_utils import compute_relative_filename
|
||||
from app.assets.services.schemas import (
|
||||
AssetData,
|
||||
AssetDetailResult,
|
||||
AssetSummaryData,
|
||||
DownloadResolutionResult,
|
||||
ListAssetsResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_reference_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def get_asset_detail(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult | None:
|
||||
with create_session() as session:
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
return None
|
||||
|
||||
ref, asset, tags = result
|
||||
return AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
|
||||
|
||||
def update_asset_metadata(
|
||||
reference_id: str,
|
||||
name: str | None = None,
|
||||
tags: Sequence[str] | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
touched = False
|
||||
if name is not None and name != ref.name:
|
||||
update_reference_name(session, reference_id=reference_id, name=name)
|
||||
touched = True
|
||||
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
|
||||
new_meta: dict | None = None
|
||||
if user_metadata is not None:
|
||||
new_meta = dict(user_metadata)
|
||||
elif computed_filename:
|
||||
current_meta = ref.user_metadata or {}
|
||||
if current_meta.get("filename") != computed_filename:
|
||||
new_meta = dict(current_meta)
|
||||
|
||||
if new_meta is not None:
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
set_reference_metadata(
|
||||
session, reference_id=reference_id, user_metadata=new_meta
|
||||
)
|
||||
touched = True
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if mime_type is not None:
|
||||
updated = update_asset_hash_and_mime(
|
||||
session, asset_id=ref.asset_id, mime_type=mime_type
|
||||
)
|
||||
if updated:
|
||||
touched = True
|
||||
|
||||
if preview_id is not None:
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_id,
|
||||
)
|
||||
touched = True
|
||||
|
||||
if touched and user_metadata is None:
|
||||
update_reference_updated_at(session, reference_id=reference_id)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during update")
|
||||
|
||||
ref, asset, tag_list = result
|
||||
detail = AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_list,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def delete_asset_reference(
|
||||
reference_id: str,
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
with create_session() as session:
|
||||
if not delete_content_if_orphan:
|
||||
# Soft delete: mark the reference as deleted but keep everything
|
||||
deleted = soft_delete_reference_by_id(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
session.commit()
|
||||
return deleted
|
||||
|
||||
ref_row = get_reference_by_id(session, reference_id=reference_id)
|
||||
asset_id = ref_row.asset_id if ref_row else None
|
||||
file_path = ref_row.file_path if ref_row else None
|
||||
|
||||
deleted = delete_reference_by_id(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not deleted:
|
||||
session.commit()
|
||||
return False
|
||||
|
||||
if not asset_id:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
|
||||
if still_exists:
|
||||
session.commit()
|
||||
return True
|
||||
|
||||
# Orphaned asset - gather ALL file paths (including
|
||||
# soft-deleted / missing refs) so their on-disk files get cleaned up.
|
||||
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
|
||||
# Also include the just-deleted file path
|
||||
if file_path:
|
||||
file_paths.append(file_path)
|
||||
|
||||
asset_row = session.get(Asset, asset_id)
|
||||
if asset_row is not None:
|
||||
session.delete(asset_row)
|
||||
|
||||
session.commit()
|
||||
|
||||
# Delete files after commit
|
||||
for p in file_paths:
|
||||
with contextlib.suppress(Exception):
|
||||
if p and os.path.isfile(p):
|
||||
os.remove(p)
|
||||
|
||||
return True
|
||||
|
||||
|
||||
def set_asset_preview(
|
||||
reference_id: str,
|
||||
preview_reference_id: str | None = None,
|
||||
owner_id: str = "",
|
||||
) -> AssetDetailResult:
|
||||
with create_session() as session:
|
||||
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
set_reference_preview(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
preview_reference_id=preview_reference_id,
|
||||
)
|
||||
|
||||
result = fetch_reference_asset_and_tags(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not result:
|
||||
raise RuntimeError("State changed during preview update")
|
||||
|
||||
ref, asset, tags = result
|
||||
detail = AssetDetailResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return detail
|
||||
|
||||
|
||||
def asset_exists(asset_hash: str) -> bool:
|
||||
with create_session() as session:
|
||||
return asset_exists_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
|
||||
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 20,
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
) -> ListAssetsResult:
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
AssetSummaryData(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(ref.asset),
|
||||
tags=tag_map.get(ref.id, []),
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
asset_hash: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult | None:
|
||||
"""Resolve a blake3 hash to an on-disk file path.
|
||||
|
||||
Only references visible to *owner_id* are considered (owner-less
|
||||
references are always visible).
|
||||
|
||||
Returns a DownloadResolutionResult with abs_path, content_type, and
|
||||
download_name, or None if no asset or live path is found.
|
||||
"""
|
||||
with create_session() as session:
|
||||
asset = queries_get_asset_by_hash(session, asset_hash)
|
||||
if not asset:
|
||||
return None
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
visible = [
|
||||
r for r in refs
|
||||
if r.owner_id == "" or r.owner_id == owner_id
|
||||
]
|
||||
abs_path = select_best_live_path(visible)
|
||||
if not abs_path:
|
||||
return None
|
||||
display_name = os.path.basename(abs_path)
|
||||
for ref in visible:
|
||||
if ref.file_path == abs_path and ref.name:
|
||||
display_name = ref.name
|
||||
break
|
||||
ctype = (
|
||||
asset.mime_type
|
||||
or mimetypes.guess_type(display_name)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=display_name,
|
||||
)
|
||||
|
||||
|
||||
def resolve_asset_for_download(
|
||||
reference_id: str,
|
||||
owner_id: str = "",
|
||||
) -> DownloadResolutionResult:
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise ValueError(f"AssetReference {reference_id} not found")
|
||||
|
||||
ref, asset = pair
|
||||
|
||||
# For references with file_path, use that directly
|
||||
if ref.file_path and os.path.isfile(ref.file_path):
|
||||
abs_path = ref.file_path
|
||||
else:
|
||||
# For API-created refs without file_path, find a path from other refs
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
abs_path = select_best_live_path(refs)
|
||||
if not abs_path:
|
||||
raise FileNotFoundError(
|
||||
f"No live path for AssetReference {reference_id} "
|
||||
f"(asset id={asset.id}, name={ref.name})"
|
||||
)
|
||||
|
||||
# Capture ORM attributes before commit (commit expires loaded objects)
|
||||
ref_name = ref.name
|
||||
asset_mime = asset.mime_type
|
||||
|
||||
update_reference_access_time(session, reference_id=reference_id)
|
||||
session.commit()
|
||||
|
||||
ctype = (
|
||||
asset_mime
|
||||
or mimetypes.guess_type(ref_name or abs_path)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
download_name = ref_name or os.path.basename(abs_path)
|
||||
return DownloadResolutionResult(
|
||||
abs_path=abs_path,
|
||||
content_type=ctype,
|
||||
download_name=download_name,
|
||||
)
|
||||
280
app/assets/services/bulk_ingest.py
Normal file
280
app/assets/services/bulk_ingest.py
Normal file
@ -0,0 +1,280 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import os
|
||||
import uuid
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.queries import (
|
||||
bulk_insert_assets,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
bulk_insert_tags_and_meta,
|
||||
delete_assets_by_ids,
|
||||
get_existing_asset_ids,
|
||||
get_reference_ids_by_ids,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from app.assets.services.metadata_extract import ExtractedMetadata
|
||||
|
||||
|
||||
class SeedAssetSpec(TypedDict):
|
||||
"""Spec for seeding an asset from filesystem."""
|
||||
|
||||
abs_path: str
|
||||
size_bytes: int
|
||||
mtime_ns: int
|
||||
info_name: str
|
||||
tags: list[str]
|
||||
fname: str
|
||||
metadata: ExtractedMetadata | None
|
||||
hash: str | None
|
||||
mime_type: str | None
|
||||
|
||||
|
||||
class AssetRow(TypedDict):
|
||||
"""Row data for inserting an Asset."""
|
||||
|
||||
id: str
|
||||
hash: str | None
|
||||
size_bytes: int
|
||||
mime_type: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
class ReferenceRow(TypedDict):
|
||||
"""Row data for inserting an AssetReference."""
|
||||
|
||||
id: str
|
||||
asset_id: str
|
||||
file_path: str
|
||||
mtime_ns: int
|
||||
owner_id: str
|
||||
name: str
|
||||
preview_id: str | None
|
||||
user_metadata: dict[str, Any] | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
last_access_time: datetime
|
||||
|
||||
|
||||
class TagRow(TypedDict):
|
||||
"""Row data for inserting a Tag."""
|
||||
|
||||
asset_reference_id: str
|
||||
tag_name: str
|
||||
origin: str
|
||||
added_at: datetime
|
||||
|
||||
|
||||
class MetadataRow(TypedDict):
|
||||
"""Row data for inserting asset metadata."""
|
||||
|
||||
asset_reference_id: str
|
||||
key: str
|
||||
ordinal: int
|
||||
val_str: str | None
|
||||
val_num: float | None
|
||||
val_bool: bool | None
|
||||
val_json: dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class BulkInsertResult:
|
||||
"""Result of bulk asset insertion."""
|
||||
|
||||
inserted_refs: int
|
||||
won_paths: int
|
||||
lost_paths: int
|
||||
|
||||
|
||||
def batch_insert_seed_assets(
|
||||
session: Session,
|
||||
specs: list[SeedAssetSpec],
|
||||
owner_id: str = "",
|
||||
) -> BulkInsertResult:
|
||||
"""Seed assets from filesystem specs in batch.
|
||||
|
||||
Each spec is a dict with keys:
|
||||
- abs_path: str
|
||||
- size_bytes: int
|
||||
- mtime_ns: int
|
||||
- info_name: str
|
||||
- tags: list[str]
|
||||
- fname: Optional[str]
|
||||
|
||||
This function orchestrates:
|
||||
1. Insert seed Assets (hash=NULL)
|
||||
2. Claim references with ON CONFLICT DO NOTHING on file_path
|
||||
3. Query to find winners (paths where our asset_id was inserted)
|
||||
4. Delete Assets for losers (path already claimed by another asset)
|
||||
5. Insert tags and metadata for successfully inserted references
|
||||
|
||||
Returns:
|
||||
BulkInsertResult with inserted_refs, won_paths, lost_paths
|
||||
"""
|
||||
if not specs:
|
||||
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
|
||||
|
||||
current_time = get_utc_now()
|
||||
asset_rows: list[AssetRow] = []
|
||||
reference_rows: list[ReferenceRow] = []
|
||||
path_to_asset_id: dict[str, str] = {}
|
||||
asset_id_to_ref_data: dict[str, dict] = {}
|
||||
absolute_path_list: list[str] = []
|
||||
|
||||
for spec in specs:
|
||||
absolute_path = os.path.abspath(spec["abs_path"])
|
||||
asset_id = str(uuid.uuid4())
|
||||
reference_id = str(uuid.uuid4())
|
||||
absolute_path_list.append(absolute_path)
|
||||
path_to_asset_id[absolute_path] = asset_id
|
||||
|
||||
mime_type = spec.get("mime_type")
|
||||
asset_rows.append(
|
||||
{
|
||||
"id": asset_id,
|
||||
"hash": spec.get("hash"),
|
||||
"size_bytes": spec["size_bytes"],
|
||||
"mime_type": mime_type,
|
||||
"created_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Build user_metadata from extracted metadata or fallback to filename
|
||||
extracted_metadata = spec.get("metadata")
|
||||
if extracted_metadata:
|
||||
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
|
||||
elif spec["fname"]:
|
||||
user_metadata = {"filename": spec["fname"]}
|
||||
else:
|
||||
user_metadata = None
|
||||
|
||||
reference_rows.append(
|
||||
{
|
||||
"id": reference_id,
|
||||
"asset_id": asset_id,
|
||||
"file_path": absolute_path,
|
||||
"mtime_ns": spec["mtime_ns"],
|
||||
"owner_id": owner_id,
|
||||
"name": spec["info_name"],
|
||||
"preview_id": None,
|
||||
"user_metadata": user_metadata,
|
||||
"created_at": current_time,
|
||||
"updated_at": current_time,
|
||||
"last_access_time": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
asset_id_to_ref_data[asset_id] = {
|
||||
"reference_id": reference_id,
|
||||
"tags": spec["tags"],
|
||||
"filename": spec["fname"],
|
||||
"extracted_metadata": extracted_metadata,
|
||||
}
|
||||
|
||||
bulk_insert_assets(session, asset_rows)
|
||||
|
||||
# Filter reference rows to only those whose assets were actually inserted
|
||||
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
|
||||
inserted_asset_ids = get_existing_asset_ids(
|
||||
session, [r["asset_id"] for r in reference_rows]
|
||||
)
|
||||
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
|
||||
|
||||
bulk_insert_references_ignore_conflicts(session, reference_rows)
|
||||
restore_references_by_paths(session, absolute_path_list)
|
||||
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
|
||||
|
||||
inserted_paths = {
|
||||
path
|
||||
for path in absolute_path_list
|
||||
if path_to_asset_id[path] in inserted_asset_ids
|
||||
}
|
||||
losing_paths = inserted_paths - winning_paths
|
||||
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
|
||||
|
||||
if lost_asset_ids:
|
||||
delete_assets_by_ids(session, lost_asset_ids)
|
||||
|
||||
if not winning_paths:
|
||||
return BulkInsertResult(
|
||||
inserted_refs=0,
|
||||
won_paths=0,
|
||||
lost_paths=len(losing_paths),
|
||||
)
|
||||
|
||||
# Get reference IDs for winners
|
||||
winning_ref_ids = [
|
||||
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
|
||||
for path in winning_paths
|
||||
]
|
||||
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
|
||||
|
||||
tag_rows: list[TagRow] = []
|
||||
metadata_rows: list[MetadataRow] = []
|
||||
|
||||
if inserted_ref_ids:
|
||||
for path in winning_paths:
|
||||
asset_id = path_to_asset_id[path]
|
||||
ref_data = asset_id_to_ref_data[asset_id]
|
||||
ref_id = ref_data["reference_id"]
|
||||
|
||||
if ref_id not in inserted_ref_ids:
|
||||
continue
|
||||
|
||||
for tag in ref_data["tags"]:
|
||||
tag_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"tag_name": tag,
|
||||
"origin": "automatic",
|
||||
"added_at": current_time,
|
||||
}
|
||||
)
|
||||
|
||||
# Use extracted metadata for meta rows if available
|
||||
extracted_metadata = ref_data.get("extracted_metadata")
|
||||
if extracted_metadata:
|
||||
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
|
||||
elif ref_data["filename"]:
|
||||
# Fallback: just store filename
|
||||
metadata_rows.append(
|
||||
{
|
||||
"asset_reference_id": ref_id,
|
||||
"key": "filename",
|
||||
"ordinal": 0,
|
||||
"val_str": ref_data["filename"],
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
}
|
||||
)
|
||||
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
|
||||
|
||||
return BulkInsertResult(
|
||||
inserted_refs=len(inserted_ref_ids),
|
||||
won_paths=len(winning_paths),
|
||||
lost_paths=len(losing_paths),
|
||||
)
|
||||
|
||||
|
||||
def cleanup_unreferenced_assets(session: Session) -> int:
|
||||
"""Hard-delete unhashed assets with no active references.
|
||||
|
||||
This is a destructive operation intended for explicit cleanup.
|
||||
Only deletes assets where hash=None and all references are missing.
|
||||
|
||||
Returns:
|
||||
Number of assets deleted
|
||||
"""
|
||||
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
|
||||
return delete_assets_by_ids(session, unreferenced_ids)
|
||||
70
app/assets/services/file_utils.py
Normal file
70
app/assets/services/file_utils.py
Normal file
@ -0,0 +1,70 @@
|
||||
import os
|
||||
|
||||
|
||||
def get_mtime_ns(stat_result: os.stat_result) -> int:
|
||||
"""Extract mtime in nanoseconds from a stat result."""
|
||||
return getattr(
|
||||
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
|
||||
)
|
||||
|
||||
|
||||
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
|
||||
"""Get file size in bytes and mtime in nanoseconds."""
|
||||
st = os.stat(path, follow_symlinks=follow_symlinks)
|
||||
return st.st_size, get_mtime_ns(st)
|
||||
|
||||
|
||||
def verify_file_unchanged(
|
||||
mtime_db: int | None,
|
||||
size_db: int | None,
|
||||
stat_result: os.stat_result,
|
||||
) -> bool:
|
||||
"""Check if a file is unchanged based on mtime and size.
|
||||
|
||||
Returns True if the file's mtime and size match the database values.
|
||||
Returns False if mtime_db is None or values don't match.
|
||||
|
||||
size_db=None means don't check size; 0 is a valid recorded size.
|
||||
"""
|
||||
if mtime_db is None:
|
||||
return False
|
||||
actual_mtime_ns = get_mtime_ns(stat_result)
|
||||
if int(mtime_db) != int(actual_mtime_ns):
|
||||
return False
|
||||
if size_db is not None:
|
||||
return int(stat_result.st_size) == int(size_db)
|
||||
return True
|
||||
|
||||
|
||||
def is_visible(name: str) -> bool:
|
||||
"""Return True if a file or directory name is visible (not hidden)."""
|
||||
return not name.startswith(".")
|
||||
|
||||
|
||||
def list_files_recursively(base_dir: str) -> list[str]:
|
||||
"""Recursively list all files in a directory, following symlinks."""
|
||||
out: list[str] = []
|
||||
base_abs = os.path.abspath(base_dir)
|
||||
if not os.path.isdir(base_abs):
|
||||
return out
|
||||
# Track seen real directory identities to prevent circular symlink loops
|
||||
seen_dirs: set[tuple[int, int]] = set()
|
||||
for dirpath, subdirs, filenames in os.walk(
|
||||
base_abs, topdown=True, followlinks=True
|
||||
):
|
||||
try:
|
||||
st = os.stat(dirpath)
|
||||
dir_id = (st.st_dev, st.st_ino)
|
||||
except OSError:
|
||||
subdirs.clear()
|
||||
continue
|
||||
if dir_id in seen_dirs:
|
||||
subdirs.clear()
|
||||
continue
|
||||
seen_dirs.add(dir_id)
|
||||
subdirs[:] = [d for d in subdirs if is_visible(d)]
|
||||
for name in filenames:
|
||||
if not is_visible(name):
|
||||
continue
|
||||
out.append(os.path.abspath(os.path.join(dirpath, name)))
|
||||
return out
|
||||
99
app/assets/services/hashing.py
Normal file
99
app/assets/services/hashing.py
Normal file
@ -0,0 +1,99 @@
|
||||
import io
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from dataclasses import dataclass
|
||||
from typing import IO, Any, Callable, Iterator
|
||||
import logging
|
||||
|
||||
try:
|
||||
from blake3 import blake3
|
||||
except ModuleNotFoundError:
|
||||
logging.warning("WARNING: blake3 package not installed")
|
||||
|
||||
DEFAULT_CHUNK = 8 * 1024 * 1024
|
||||
|
||||
InterruptCheck = Callable[[], bool]
|
||||
|
||||
|
||||
@dataclass
|
||||
class HashCheckpoint:
|
||||
"""Saved state for resuming an interrupted hash computation."""
|
||||
|
||||
bytes_processed: int
|
||||
hasher: Any # blake3 hasher instance
|
||||
mtime_ns: int = 0
|
||||
file_size: int = 0
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
|
||||
"""Yield (file_object, is_path) with appropriate setup/teardown."""
|
||||
if hasattr(fp, "read"):
|
||||
seekable = getattr(fp, "seekable", lambda: False)()
|
||||
orig_pos = None
|
||||
if seekable:
|
||||
try:
|
||||
orig_pos = fp.tell()
|
||||
if orig_pos != 0:
|
||||
fp.seek(0)
|
||||
except io.UnsupportedOperation:
|
||||
orig_pos = None
|
||||
try:
|
||||
yield fp, False
|
||||
finally:
|
||||
if orig_pos is not None:
|
||||
fp.seek(orig_pos)
|
||||
else:
|
||||
with open(os.fspath(fp), "rb") as f:
|
||||
yield f, True
|
||||
|
||||
|
||||
def compute_blake3_hash(
|
||||
fp: str | IO[bytes],
|
||||
chunk_size: int = DEFAULT_CHUNK,
|
||||
interrupt_check: InterruptCheck | None = None,
|
||||
checkpoint: HashCheckpoint | None = None,
|
||||
) -> tuple[str | None, HashCheckpoint | None]:
|
||||
"""Compute BLAKE3 hash of a file, with optional checkpoint support.
|
||||
|
||||
Args:
|
||||
fp: File path or file-like object
|
||||
chunk_size: Size of chunks to read at a time
|
||||
interrupt_check: Optional callable that returns True if the operation
|
||||
should be interrupted (e.g. paused or cancelled). Must be
|
||||
non-blocking so file handles are released immediately. Checked
|
||||
between chunk reads.
|
||||
checkpoint: Optional checkpoint to resume from (file paths only)
|
||||
|
||||
Returns:
|
||||
Tuple of (hex_digest, None) on completion, or
|
||||
(None, checkpoint) on interruption (file paths only), or
|
||||
(None, None) on interruption of a file object
|
||||
"""
|
||||
if chunk_size <= 0:
|
||||
chunk_size = DEFAULT_CHUNK
|
||||
|
||||
with _open_for_hashing(fp) as (f, is_path):
|
||||
if checkpoint is not None and is_path:
|
||||
f.seek(checkpoint.bytes_processed)
|
||||
h = checkpoint.hasher
|
||||
bytes_processed = checkpoint.bytes_processed
|
||||
else:
|
||||
h = blake3()
|
||||
bytes_processed = 0
|
||||
|
||||
while True:
|
||||
if interrupt_check is not None and interrupt_check():
|
||||
if is_path:
|
||||
return None, HashCheckpoint(
|
||||
bytes_processed=bytes_processed,
|
||||
hasher=h,
|
||||
)
|
||||
return None, None
|
||||
chunk = f.read(chunk_size)
|
||||
if not chunk:
|
||||
break
|
||||
h.update(chunk)
|
||||
bytes_processed += len(chunk)
|
||||
|
||||
return h.hexdigest(), None
|
||||
463
app/assets/services/ingest.py
Normal file
463
app/assets/services/ingest.py
Normal file
@ -0,0 +1,463 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
from typing import Any, Sequence
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
import app.assets.services.hashing as hashing
|
||||
from app.assets.database.queries import (
|
||||
add_tags_to_reference,
|
||||
fetch_reference_and_asset,
|
||||
get_asset_by_hash,
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
upsert_reference,
|
||||
validate_tags_exist,
|
||||
)
|
||||
from app.assets.helpers import normalize_tags
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
resolve_destination_from_tags,
|
||||
validate_path_within_base,
|
||||
)
|
||||
from app.assets.services.schemas import (
|
||||
IngestResult,
|
||||
RegisterAssetResult,
|
||||
UploadResult,
|
||||
UserMetadata,
|
||||
extract_asset_data,
|
||||
extract_reference_data,
|
||||
)
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def _ingest_file_from_path(
|
||||
abs_path: str,
|
||||
asset_hash: str,
|
||||
size_bytes: int,
|
||||
mtime_ns: int,
|
||||
mime_type: str | None = None,
|
||||
info_name: str | None = None,
|
||||
owner_id: str = "",
|
||||
preview_id: str | None = None,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: Sequence[str] = (),
|
||||
tag_origin: str = "manual",
|
||||
require_existing_tags: bool = False,
|
||||
) -> IngestResult:
|
||||
locator = os.path.abspath(abs_path)
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
asset_created = False
|
||||
asset_updated = False
|
||||
ref_created = False
|
||||
ref_updated = False
|
||||
reference_id: str | None = None
|
||||
|
||||
with create_session() as session:
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
asset, asset_created, asset_updated = upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mime_type=mime_type,
|
||||
)
|
||||
|
||||
ref_created, ref_updated = upsert_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
file_path=locator,
|
||||
name=info_name or os.path.basename(locator),
|
||||
mtime_ns=mtime_ns,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
# Get the reference we just created/updated
|
||||
ref = get_reference_by_file_path(session, locator)
|
||||
if ref:
|
||||
reference_id = ref.id
|
||||
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
norm = normalize_tags(list(tags))
|
||||
if norm:
|
||||
if require_existing_tags:
|
||||
validate_tags_exist(session, norm)
|
||||
add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=norm,
|
||||
origin=tag_origin,
|
||||
create_if_missing=not require_existing_tags,
|
||||
)
|
||||
|
||||
_update_metadata_with_filename(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
file_path=ref.file_path,
|
||||
current_metadata=ref.user_metadata,
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
|
||||
|
||||
session.commit()
|
||||
|
||||
return IngestResult(
|
||||
asset_created=asset_created,
|
||||
asset_updated=asset_updated,
|
||||
ref_created=ref_created,
|
||||
ref_updated=ref_updated,
|
||||
reference_id=reference_id,
|
||||
)
|
||||
|
||||
|
||||
def _register_existing_asset(
|
||||
asset_hash: str,
|
||||
name: str,
|
||||
user_metadata: UserMetadata = None,
|
||||
tags: list[str] | None = None,
|
||||
tag_origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> RegisterAssetResult:
|
||||
user_metadata = user_metadata or {}
|
||||
|
||||
with create_session() as session:
|
||||
asset = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
if not asset:
|
||||
raise ValueError(f"No asset with hash {asset_hash}")
|
||||
|
||||
if mime_type and not asset.mime_type:
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
|
||||
|
||||
if preview_id:
|
||||
if not reference_exists(session, preview_id):
|
||||
preview_id = None
|
||||
|
||||
ref, ref_created = get_or_create_reference(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
|
||||
if not ref_created:
|
||||
if preview_id and ref.preview_id != preview_id:
|
||||
ref.preview_id = preview_id
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=False,
|
||||
)
|
||||
session.commit()
|
||||
return result
|
||||
|
||||
new_meta = dict(user_metadata)
|
||||
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta:
|
||||
set_reference_metadata(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
reference_id=ref.id,
|
||||
tags=tags,
|
||||
origin=tag_origin,
|
||||
)
|
||||
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
session.refresh(ref)
|
||||
result = RegisterAssetResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created=True,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
|
||||
def _update_metadata_with_filename(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
file_path: str | None,
|
||||
current_metadata: dict | None,
|
||||
user_metadata: dict[str, Any],
|
||||
) -> None:
|
||||
computed_filename = compute_relative_filename(file_path) if file_path else None
|
||||
|
||||
current_meta = current_metadata or {}
|
||||
new_meta = dict(current_meta)
|
||||
for k, v in user_metadata.items():
|
||||
new_meta[k] = v
|
||||
if computed_filename:
|
||||
new_meta["filename"] = computed_filename
|
||||
|
||||
if new_meta != current_meta:
|
||||
set_reference_metadata(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
|
||||
def _sanitize_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
return n if n else fallback
|
||||
|
||||
|
||||
class HashMismatchError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class DependencyMissingError(Exception):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
super().__init__(message)
|
||||
|
||||
|
||||
def upload_from_temp_path(
|
||||
temp_path: str,
|
||||
name: str | None = None,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
client_filename: str | None = None,
|
||||
owner_id: str = "",
|
||||
expected_hash: str | None = None,
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult:
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash uploaded file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
if expected_hash and asset_hash != expected_hash.strip().lower():
|
||||
raise HashMismatchError("Uploaded file hash does not match provided hash.")
|
||||
|
||||
with create_session() as session:
|
||||
existing = get_asset_by_hash(session, asset_hash=asset_hash)
|
||||
|
||||
if existing is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
if temp_path and os.path.exists(temp_path):
|
||||
os.remove(temp_path)
|
||||
|
||||
display_name = _sanitize_filename(name or client_filename, fallback=digest)
|
||||
result = _register_existing_asset(
|
||||
asset_hash=asset_hash,
|
||||
name=display_name,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
asset=result.asset,
|
||||
tags=result.tags,
|
||||
created_new=False,
|
||||
)
|
||||
|
||||
if not tags:
|
||||
raise ValueError("tags are required for new asset uploads")
|
||||
base_dir, subdirs = resolve_destination_from_tags(tags)
|
||||
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
|
||||
os.makedirs(dest_dir, exist_ok=True)
|
||||
|
||||
src_for_ext = (client_filename or name or "").strip()
|
||||
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
|
||||
ext = _ext if 0 < len(_ext) <= 16 else ""
|
||||
hashed_basename = f"{digest}{ext}"
|
||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||
validate_path_within_base(dest_abs, base_dir)
|
||||
|
||||
content_type = mime_type or (
|
||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
try:
|
||||
os.replace(temp_path, dest_abs)
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to move uploaded file into place: {e}")
|
||||
|
||||
try:
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
|
||||
except OSError as e:
|
||||
raise RuntimeError(f"failed to stat destination file: {e}")
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
asset_hash=asset_hash,
|
||||
abs_path=dest_abs,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
preview_id=preview_id,
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags,
|
||||
tag_origin="manual",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def register_file_in_place(
|
||||
abs_path: str,
|
||||
name: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
) -> UploadResult:
|
||||
"""Register an already-saved file in the asset database without moving it.
|
||||
|
||||
Tags are derived from the filesystem path (root category + subfolder names),
|
||||
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||
If the path is not under a known root, only the caller-provided tags are used.
|
||||
"""
|
||||
try:
|
||||
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||
except ValueError:
|
||||
path_tags = []
|
||||
merged_tags = normalize_tags([*path_tags, *tags])
|
||||
|
||||
try:
|
||||
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||
except ImportError as e:
|
||||
raise DependencyMissingError(str(e))
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"failed to hash file: {e}")
|
||||
asset_hash = "blake3:" + digest
|
||||
|
||||
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||
content_type = mime_type or (
|
||||
mimetypes.guess_type(abs_path, strict=False)[0]
|
||||
or "application/octet-stream"
|
||||
)
|
||||
|
||||
ingest_result = _ingest_file_from_path(
|
||||
abs_path=abs_path,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=size_bytes,
|
||||
mtime_ns=mtime_ns,
|
||||
mime_type=content_type,
|
||||
info_name=_sanitize_filename(name, fallback=digest),
|
||||
owner_id=owner_id,
|
||||
tags=merged_tags,
|
||||
tag_origin="upload",
|
||||
require_existing_tags=False,
|
||||
)
|
||||
reference_id = ingest_result.reference_id
|
||||
if not reference_id:
|
||||
raise RuntimeError("failed to create asset reference")
|
||||
|
||||
with create_session() as session:
|
||||
pair = fetch_reference_and_asset(
|
||||
session, reference_id=reference_id, owner_id=owner_id
|
||||
)
|
||||
if not pair:
|
||||
raise RuntimeError("inconsistent DB state after ingest")
|
||||
ref, asset = pair
|
||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||
|
||||
return UploadResult(
|
||||
ref=extract_reference_data(ref),
|
||||
asset=extract_asset_data(asset),
|
||||
tags=tag_names,
|
||||
created_new=ingest_result.asset_created,
|
||||
)
|
||||
|
||||
|
||||
def create_from_hash(
|
||||
hash_str: str,
|
||||
name: str,
|
||||
tags: list[str] | None = None,
|
||||
user_metadata: dict | None = None,
|
||||
owner_id: str = "",
|
||||
mime_type: str | None = None,
|
||||
preview_id: str | None = None,
|
||||
) -> UploadResult | None:
|
||||
canonical = hash_str.strip().lower()
|
||||
|
||||
try:
|
||||
result = _register_existing_asset(
|
||||
asset_hash=canonical,
|
||||
name=_sanitize_filename(
|
||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||
),
|
||||
user_metadata=user_metadata or {},
|
||||
tags=tags or [],
|
||||
tag_origin="manual",
|
||||
owner_id=owner_id,
|
||||
mime_type=mime_type,
|
||||
preview_id=preview_id,
|
||||
)
|
||||
except ValueError:
|
||||
logging.warning("create_from_hash: no asset found for hash %s", canonical)
|
||||
return None
|
||||
|
||||
return UploadResult(
|
||||
ref=result.ref,
|
||||
asset=result.asset,
|
||||
tags=result.tags,
|
||||
created_new=False,
|
||||
)
|
||||
327
app/assets/services/metadata_extract.py
Normal file
327
app/assets/services/metadata_extract.py
Normal file
@ -0,0 +1,327 @@
|
||||
"""Metadata extraction for asset scanning.
|
||||
|
||||
Tier 1: Filesystem metadata (zero parsing)
|
||||
Tier 2: Safetensors header metadata (fast JSON read only)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
import os
|
||||
import struct
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from utils.mime_types import init_mime_types
|
||||
|
||||
init_mime_types()
|
||||
|
||||
# Supported safetensors extensions
|
||||
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
|
||||
|
||||
# Maximum safetensors header size to read (8MB)
|
||||
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
|
||||
|
||||
|
||||
@dataclass
|
||||
class ExtractedMetadata:
|
||||
"""Metadata extracted from a file during scanning."""
|
||||
|
||||
# Tier 1: Filesystem (always available)
|
||||
filename: str = ""
|
||||
file_path: str = "" # Full absolute path to the file
|
||||
content_length: int = 0
|
||||
content_type: str | None = None
|
||||
format: str = "" # file extension without dot
|
||||
|
||||
# Tier 2: Safetensors header (if available)
|
||||
base_model: str | None = None
|
||||
trained_words: list[str] | None = None
|
||||
air: str | None = None # CivitAI AIR identifier
|
||||
has_preview_images: bool = False
|
||||
|
||||
# Source provenance (populated if embedded in safetensors)
|
||||
source_url: str | None = None
|
||||
source_arn: str | None = None
|
||||
repo_url: str | None = None
|
||||
preview_url: str | None = None
|
||||
source_hash: str | None = None
|
||||
|
||||
# HuggingFace specific
|
||||
repo_id: str | None = None
|
||||
revision: str | None = None
|
||||
filepath: str | None = None
|
||||
resolve_url: str | None = None
|
||||
|
||||
def to_user_metadata(self) -> dict[str, Any]:
|
||||
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
|
||||
data: dict[str, Any] = {
|
||||
"filename": self.filename,
|
||||
"content_length": self.content_length,
|
||||
"format": self.format,
|
||||
}
|
||||
if self.file_path:
|
||||
data["file_path"] = self.file_path
|
||||
if self.content_type:
|
||||
data["content_type"] = self.content_type
|
||||
|
||||
# Tier 2 fields
|
||||
if self.base_model:
|
||||
data["base_model"] = self.base_model
|
||||
if self.trained_words:
|
||||
data["trained_words"] = self.trained_words
|
||||
if self.air:
|
||||
data["air"] = self.air
|
||||
if self.has_preview_images:
|
||||
data["has_preview_images"] = True
|
||||
|
||||
# Source provenance
|
||||
if self.source_url:
|
||||
data["source_url"] = self.source_url
|
||||
if self.source_arn:
|
||||
data["source_arn"] = self.source_arn
|
||||
if self.repo_url:
|
||||
data["repo_url"] = self.repo_url
|
||||
if self.preview_url:
|
||||
data["preview_url"] = self.preview_url
|
||||
if self.source_hash:
|
||||
data["source_hash"] = self.source_hash
|
||||
|
||||
# HuggingFace
|
||||
if self.repo_id:
|
||||
data["repo_id"] = self.repo_id
|
||||
if self.revision:
|
||||
data["revision"] = self.revision
|
||||
if self.filepath:
|
||||
data["filepath"] = self.filepath
|
||||
if self.resolve_url:
|
||||
data["resolve_url"] = self.resolve_url
|
||||
|
||||
return data
|
||||
|
||||
def to_meta_rows(self, reference_id: str) -> list[dict]:
|
||||
"""Convert to asset_reference_meta rows for typed/indexed querying."""
|
||||
rows: list[dict] = []
|
||||
|
||||
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
|
||||
if val:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": ordinal,
|
||||
"val_str": val[:2048] if len(val) > 2048 else val,
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
def add_num(key: str, val: int | float | None) -> None:
|
||||
if val is not None:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": 0,
|
||||
"val_str": None,
|
||||
"val_num": val,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
def add_bool(key: str, val: bool | None) -> None:
|
||||
if val is not None:
|
||||
rows.append({
|
||||
"asset_reference_id": reference_id,
|
||||
"key": key,
|
||||
"ordinal": 0,
|
||||
"val_str": None,
|
||||
"val_num": None,
|
||||
"val_bool": val,
|
||||
"val_json": None,
|
||||
})
|
||||
|
||||
# Tier 1
|
||||
add_str("filename", self.filename)
|
||||
add_num("content_length", self.content_length)
|
||||
add_str("content_type", self.content_type)
|
||||
add_str("format", self.format)
|
||||
|
||||
# Tier 2
|
||||
add_str("base_model", self.base_model)
|
||||
add_str("air", self.air)
|
||||
has_previews = self.has_preview_images if self.has_preview_images else None
|
||||
add_bool("has_preview_images", has_previews)
|
||||
|
||||
# trained_words as multiple rows with ordinals
|
||||
if self.trained_words:
|
||||
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
|
||||
add_str("trained_words", word, ordinal=i)
|
||||
|
||||
# Source provenance
|
||||
add_str("source_url", self.source_url)
|
||||
add_str("source_arn", self.source_arn)
|
||||
add_str("repo_url", self.repo_url)
|
||||
add_str("preview_url", self.preview_url)
|
||||
add_str("source_hash", self.source_hash)
|
||||
|
||||
# HuggingFace
|
||||
add_str("repo_id", self.repo_id)
|
||||
add_str("revision", self.revision)
|
||||
add_str("filepath", self.filepath)
|
||||
add_str("resolve_url", self.resolve_url)
|
||||
|
||||
return rows
|
||||
|
||||
|
||||
def _read_safetensors_header(
|
||||
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
|
||||
) -> dict[str, Any] | None:
|
||||
"""Read only the JSON header from a safetensors file.
|
||||
|
||||
This is very fast - reads 8 bytes for header length, then the JSON header.
|
||||
No tensor data is loaded.
|
||||
|
||||
Args:
|
||||
path: Absolute path to safetensors file
|
||||
max_size: Maximum header size to read (default 8MB)
|
||||
|
||||
Returns:
|
||||
Parsed header dict or None if failed
|
||||
"""
|
||||
try:
|
||||
with open(path, "rb") as f:
|
||||
header_bytes = f.read(8)
|
||||
if len(header_bytes) < 8:
|
||||
return None
|
||||
length_of_header = struct.unpack("<Q", header_bytes)[0]
|
||||
if length_of_header > max_size:
|
||||
return None
|
||||
header_data = f.read(length_of_header)
|
||||
if len(header_data) < length_of_header:
|
||||
return None
|
||||
return json.loads(header_data.decode("utf-8"))
|
||||
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
|
||||
return None
|
||||
|
||||
|
||||
def _extract_safetensors_metadata(
|
||||
header: dict[str, Any], meta: ExtractedMetadata
|
||||
) -> None:
|
||||
"""Extract metadata from safetensors header __metadata__ section.
|
||||
|
||||
Modifies meta in-place.
|
||||
"""
|
||||
st_meta = header.get("__metadata__", {})
|
||||
if not isinstance(st_meta, dict):
|
||||
return
|
||||
|
||||
# Common model metadata
|
||||
meta.base_model = (
|
||||
st_meta.get("ss_base_model_version")
|
||||
or st_meta.get("modelspec.base_model")
|
||||
or st_meta.get("base_model")
|
||||
)
|
||||
|
||||
# Trained words / trigger words
|
||||
trained_words = st_meta.get("ss_tag_frequency")
|
||||
if trained_words and isinstance(trained_words, str):
|
||||
try:
|
||||
tag_freq = json.loads(trained_words)
|
||||
# Extract unique tags from all datasets
|
||||
all_tags: set[str] = set()
|
||||
for dataset_tags in tag_freq.values():
|
||||
if isinstance(dataset_tags, dict):
|
||||
all_tags.update(dataset_tags.keys())
|
||||
if all_tags:
|
||||
meta.trained_words = sorted(all_tags)[:100]
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# Direct trained_words field (some formats)
|
||||
if not meta.trained_words:
|
||||
tw = st_meta.get("trained_words")
|
||||
if isinstance(tw, str):
|
||||
try:
|
||||
parsed = json.loads(tw)
|
||||
if isinstance(parsed, list):
|
||||
meta.trained_words = [str(x) for x in parsed]
|
||||
else:
|
||||
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
||||
except json.JSONDecodeError:
|
||||
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
|
||||
elif isinstance(tw, list):
|
||||
meta.trained_words = [str(x) for x in tw]
|
||||
|
||||
# CivitAI AIR
|
||||
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
|
||||
|
||||
# Preview images (ssmd_cover_images)
|
||||
cover_images = st_meta.get("ssmd_cover_images")
|
||||
if cover_images:
|
||||
meta.has_preview_images = True
|
||||
|
||||
# Source provenance fields
|
||||
meta.source_url = st_meta.get("source_url")
|
||||
meta.source_arn = st_meta.get("source_arn")
|
||||
meta.repo_url = st_meta.get("repo_url")
|
||||
meta.preview_url = st_meta.get("preview_url")
|
||||
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
|
||||
|
||||
# HuggingFace fields
|
||||
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
|
||||
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
|
||||
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
|
||||
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
|
||||
|
||||
|
||||
def extract_file_metadata(
|
||||
abs_path: str,
|
||||
stat_result: os.stat_result | None = None,
|
||||
relative_filename: str | None = None,
|
||||
) -> ExtractedMetadata:
|
||||
"""Extract metadata from a file using tier 1 and tier 2 methods.
|
||||
|
||||
Tier 1: Filesystem metadata from path and stat
|
||||
Tier 2: Safetensors header parsing if applicable
|
||||
|
||||
Args:
|
||||
abs_path: Absolute path to the file
|
||||
stat_result: Optional pre-fetched stat result (saves a syscall)
|
||||
relative_filename: Optional relative filename to use instead of basename
|
||||
(e.g., "flux/123/model.safetensors" for model paths)
|
||||
|
||||
Returns:
|
||||
ExtractedMetadata with all available fields populated
|
||||
"""
|
||||
meta = ExtractedMetadata()
|
||||
|
||||
# Tier 1: Filesystem metadata
|
||||
meta.filename = relative_filename or os.path.basename(abs_path)
|
||||
meta.file_path = abs_path
|
||||
_, ext = os.path.splitext(abs_path)
|
||||
meta.format = ext.lstrip(".").lower() if ext else ""
|
||||
|
||||
mime_type, _ = mimetypes.guess_type(abs_path)
|
||||
meta.content_type = mime_type
|
||||
|
||||
# Size from stat
|
||||
if stat_result is None:
|
||||
try:
|
||||
stat_result = os.stat(abs_path, follow_symlinks=True)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
if stat_result:
|
||||
meta.content_length = stat_result.st_size
|
||||
|
||||
# Tier 2: Safetensors header (if applicable and enabled)
|
||||
if ext.lower() in SAFETENSORS_EXTENSIONS:
|
||||
header = _read_safetensors_header(abs_path)
|
||||
if header:
|
||||
try:
|
||||
_extract_safetensors_metadata(header, meta)
|
||||
except Exception as e:
|
||||
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
|
||||
|
||||
return meta
|
||||
167
app/assets/services/path_utils.py
Normal file
167
app/assets/services/path_utils.py
Normal file
@ -0,0 +1,167 @@
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import Literal
|
||||
|
||||
import folder_paths
|
||||
from app.assets.helpers import normalize_tags
|
||||
|
||||
|
||||
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
|
||||
|
||||
|
||||
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
|
||||
"""Build list of (folder_name, base_paths[]) for all model locations.
|
||||
|
||||
Includes every category registered in folder_names_and_paths,
|
||||
regardless of whether its paths are under the main models_dir,
|
||||
but excludes non-model entries like custom_nodes.
|
||||
"""
|
||||
targets: list[tuple[str, list[str]]] = []
|
||||
for name, values in folder_paths.folder_names_and_paths.items():
|
||||
if name in _NON_MODEL_FOLDER_NAMES:
|
||||
continue
|
||||
paths, _exts = values[0], values[1]
|
||||
if paths:
|
||||
targets.append((name, paths))
|
||||
return targets
|
||||
|
||||
|
||||
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
|
||||
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
|
||||
if not tags:
|
||||
raise ValueError("tags must not be empty")
|
||||
root = tags[0].lower()
|
||||
if root == "models":
|
||||
if len(tags) < 2:
|
||||
raise ValueError("at least two tags required for model asset")
|
||||
try:
|
||||
bases = folder_paths.folder_names_and_paths[tags[1]][0]
|
||||
except KeyError:
|
||||
raise ValueError(f"unknown model category '{tags[1]}'")
|
||||
if not bases:
|
||||
raise ValueError(f"no base path configured for category '{tags[1]}'")
|
||||
base_dir = os.path.abspath(bases[0])
|
||||
raw_subdirs = tags[2:]
|
||||
elif root == "input":
|
||||
base_dir = os.path.abspath(folder_paths.get_input_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
elif root == "output":
|
||||
base_dir = os.path.abspath(folder_paths.get_output_directory())
|
||||
raw_subdirs = tags[1:]
|
||||
else:
|
||||
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
|
||||
_sep_chars = frozenset(("/", "\\", os.sep))
|
||||
for i in raw_subdirs:
|
||||
if i in (".", "..") or _sep_chars & set(i):
|
||||
raise ValueError("invalid path component in tags")
|
||||
|
||||
return base_dir, raw_subdirs if raw_subdirs else []
|
||||
|
||||
|
||||
def validate_path_within_base(candidate: str, base: str) -> None:
|
||||
cand_abs = Path(os.path.abspath(candidate))
|
||||
base_abs = Path(os.path.abspath(base))
|
||||
if not cand_abs.is_relative_to(base_abs):
|
||||
raise ValueError("destination escapes base directory")
|
||||
|
||||
|
||||
def compute_relative_filename(file_path: str) -> str | None:
|
||||
"""
|
||||
Return the model's path relative to the last well-known folder (the model category),
|
||||
using forward slashes, eg:
|
||||
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
|
||||
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
|
||||
|
||||
For non-model paths, returns None.
|
||||
"""
|
||||
try:
|
||||
root_category, rel_path = get_asset_category_and_relative_path(file_path)
|
||||
except ValueError:
|
||||
return None
|
||||
|
||||
p = Path(rel_path)
|
||||
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
|
||||
if not parts:
|
||||
return None
|
||||
|
||||
if root_category == "models":
|
||||
# parts[0] is the category ("checkpoints", "vae", etc) – drop it
|
||||
inside = parts[1:] if len(parts) > 1 else [parts[0]]
|
||||
return "/".join(inside)
|
||||
return "/".join(parts) # input/output: keep all parts
|
||||
|
||||
|
||||
def get_asset_category_and_relative_path(
|
||||
file_path: str,
|
||||
) -> tuple[Literal["input", "output", "models"], str]:
|
||||
"""Determine which root category a file path belongs to.
|
||||
|
||||
Categories:
|
||||
- 'input': under folder_paths.get_input_directory()
|
||||
- 'output': under folder_paths.get_output_directory()
|
||||
- 'models': under any base path from get_comfy_models_folders()
|
||||
|
||||
Returns:
|
||||
(root_category, relative_path_inside_that_root)
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
fp_abs = os.path.abspath(file_path)
|
||||
|
||||
def _check_is_within(child: str, parent: str) -> bool:
|
||||
return Path(child).is_relative_to(parent)
|
||||
|
||||
def _compute_relative(child: str, parent: str) -> str:
|
||||
# Normalize relative path, stripping any leading ".." components
|
||||
# by anchoring to root (os.sep) then computing relpath back from it.
|
||||
return os.path.relpath(
|
||||
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
|
||||
)
|
||||
|
||||
# 1) input
|
||||
input_base = os.path.abspath(folder_paths.get_input_directory())
|
||||
if _check_is_within(fp_abs, input_base):
|
||||
return "input", _compute_relative(fp_abs, input_base)
|
||||
|
||||
# 2) output
|
||||
output_base = os.path.abspath(folder_paths.get_output_directory())
|
||||
if _check_is_within(fp_abs, output_base):
|
||||
return "output", _compute_relative(fp_abs, output_base)
|
||||
|
||||
# 3) models (check deepest matching base to avoid ambiguity)
|
||||
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
|
||||
for bucket, bases in get_comfy_models_folders():
|
||||
for b in bases:
|
||||
base_abs = os.path.abspath(b)
|
||||
if not _check_is_within(fp_abs, base_abs):
|
||||
continue
|
||||
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
|
||||
if best is None or cand[0] > best[0]:
|
||||
best = cand
|
||||
|
||||
if best is not None:
|
||||
_, bucket, rel_inside = best
|
||||
combined = os.path.join(bucket, rel_inside)
|
||||
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
|
||||
|
||||
raise ValueError(
|
||||
f"Path is not within input, output, or configured model bases: {file_path}"
|
||||
)
|
||||
|
||||
|
||||
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
|
||||
"""Return (name, tags) derived from a filesystem path.
|
||||
|
||||
- name: base filename with extension
|
||||
- tags: [root_category] + parent folder names in order
|
||||
|
||||
Raises:
|
||||
ValueError: path does not belong to any known root.
|
||||
"""
|
||||
root_category, some_path = get_asset_category_and_relative_path(file_path)
|
||||
p = Path(some_path)
|
||||
parent_parts = [
|
||||
part for part in p.parent.parts if part not in (".", "..", p.anchor)
|
||||
]
|
||||
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
|
||||
113
app/assets/services/schemas.py
Normal file
113
app/assets/services/schemas.py
Normal file
@ -0,0 +1,113 @@
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NamedTuple
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
|
||||
UserMetadata = dict[str, Any] | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetData:
|
||||
hash: str | None
|
||||
size_bytes: int | None
|
||||
mime_type: str | None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ReferenceData:
|
||||
"""Data transfer object for AssetReference."""
|
||||
|
||||
id: str
|
||||
name: str
|
||||
file_path: str | None
|
||||
user_metadata: UserMetadata
|
||||
preview_id: str | None
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
system_metadata: dict[str, Any] | None = None
|
||||
job_id: str | None = None
|
||||
last_access_time: datetime | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetDetailResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class RegisterAssetResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class IngestResult:
|
||||
asset_created: bool
|
||||
asset_updated: bool
|
||||
ref_created: bool
|
||||
ref_updated: bool
|
||||
reference_id: str | None
|
||||
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class AssetSummaryData:
|
||||
ref: ReferenceData
|
||||
asset: AssetData | None
|
||||
tags: list[str]
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class DownloadResolutionResult:
|
||||
abs_path: str
|
||||
content_type: str
|
||||
download_name: str
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class UploadResult:
|
||||
ref: ReferenceData
|
||||
asset: AssetData
|
||||
tags: list[str]
|
||||
created_new: bool
|
||||
|
||||
|
||||
def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
||||
return ReferenceData(
|
||||
id=ref.id,
|
||||
name=ref.name,
|
||||
file_path=ref.file_path,
|
||||
user_metadata=ref.user_metadata,
|
||||
preview_id=ref.preview_id,
|
||||
system_metadata=ref.system_metadata,
|
||||
job_id=ref.job_id,
|
||||
created_at=ref.created_at,
|
||||
updated_at=ref.updated_at,
|
||||
last_access_time=ref.last_access_time,
|
||||
)
|
||||
|
||||
|
||||
def extract_asset_data(asset: Asset | None) -> AssetData | None:
|
||||
if asset is None:
|
||||
return None
|
||||
return AssetData(
|
||||
hash=asset.hash,
|
||||
size_bytes=asset.size_bytes,
|
||||
mime_type=asset.mime_type,
|
||||
)
|
||||
98
app/assets/services/tagging.py
Normal file
98
app/assets/services/tagging.py
Normal file
@ -0,0 +1,98 @@
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.database.queries import (
|
||||
AddTagsResult,
|
||||
RemoveTagsResult,
|
||||
add_tags_to_reference,
|
||||
get_reference_with_owner_check,
|
||||
list_tags_with_usage,
|
||||
remove_tags_from_reference,
|
||||
)
|
||||
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
|
||||
from app.assets.services.schemas import TagUsage
|
||||
from app.database.db import create_session
|
||||
|
||||
|
||||
def apply_tags(
|
||||
reference_id: str,
|
||||
tags: list[str],
|
||||
origin: str = "manual",
|
||||
owner_id: str = "",
|
||||
) -> AddTagsResult:
|
||||
with create_session() as session:
|
||||
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
result = add_tags_to_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
origin=origin,
|
||||
create_if_missing=True,
|
||||
reference_row=ref_row,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def remove_tags(
|
||||
reference_id: str,
|
||||
tags: list[str],
|
||||
owner_id: str = "",
|
||||
) -> RemoveTagsResult:
|
||||
with create_session() as session:
|
||||
get_reference_with_owner_check(session, reference_id, owner_id)
|
||||
|
||||
result = remove_tags_from_reference(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
tags=tags,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def list_tags(
|
||||
prefix: str | None = None,
|
||||
limit: int = 100,
|
||||
offset: int = 0,
|
||||
order: str = "count_desc",
|
||||
include_zero: bool = True,
|
||||
owner_id: str = "",
|
||||
) -> tuple[list[TagUsage], int]:
|
||||
limit = max(1, min(1000, limit))
|
||||
offset = max(0, offset)
|
||||
|
||||
with create_session() as session:
|
||||
rows, total = list_tags_with_usage(
|
||||
session,
|
||||
prefix=prefix,
|
||||
limit=limit,
|
||||
offset=offset,
|
||||
include_zero=include_zero,
|
||||
order=order,
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
exclude_tags: Sequence[str] | None = None,
|
||||
name_contains: str | None = None,
|
||||
metadata_filter: dict | None = None,
|
||||
limit: int = 100,
|
||||
) -> dict[str, int]:
|
||||
with create_session() as session:
|
||||
return list_tag_counts_for_filtered_assets(
|
||||
session,
|
||||
owner_id=owner_id,
|
||||
include_tags=include_tags,
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
)
|
||||
@ -3,6 +3,7 @@ import os
|
||||
import shutil
|
||||
from app.logger import log_startup_warning
|
||||
from utils.install_util import get_missing_requirements_message
|
||||
from filelock import FileLock, Timeout
|
||||
from comfy.cli_args import args
|
||||
|
||||
_DB_AVAILABLE = False
|
||||
@ -14,8 +15,12 @@ try:
|
||||
from alembic.config import Config
|
||||
from alembic.runtime.migration import MigrationContext
|
||||
from alembic.script import ScriptDirectory
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy import create_engine, event
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from sqlalchemy.pool import StaticPool
|
||||
|
||||
from app.database.models import Base
|
||||
import app.assets.database.models # noqa: F401 — register models with Base.metadata
|
||||
|
||||
_DB_AVAILABLE = True
|
||||
except ImportError as e:
|
||||
@ -65,9 +70,69 @@ def get_db_path():
|
||||
raise ValueError(f"Unsupported database URL '{url}'.")
|
||||
|
||||
|
||||
_db_lock = None
|
||||
|
||||
def _acquire_file_lock(db_path):
|
||||
"""Acquire an OS-level file lock to prevent multi-process access.
|
||||
|
||||
Uses filelock for cross-platform support (macOS, Linux, Windows).
|
||||
The OS automatically releases the lock when the process exits, even on crashes.
|
||||
"""
|
||||
global _db_lock
|
||||
lock_path = db_path + ".lock"
|
||||
_db_lock = FileLock(lock_path)
|
||||
try:
|
||||
_db_lock.acquire(timeout=0)
|
||||
except Timeout:
|
||||
raise RuntimeError(
|
||||
f"Could not acquire lock on database '{db_path}'. "
|
||||
"Another ComfyUI process may already be using it. "
|
||||
"Use --database-url to specify a separate database file."
|
||||
)
|
||||
|
||||
|
||||
def _is_memory_db(db_url):
|
||||
"""Check if the database URL refers to an in-memory SQLite database."""
|
||||
return db_url in ("sqlite:///:memory:", "sqlite://")
|
||||
|
||||
|
||||
def init_db():
|
||||
db_url = args.database_url
|
||||
logging.debug(f"Database URL: {db_url}")
|
||||
|
||||
if _is_memory_db(db_url):
|
||||
_init_memory_db(db_url)
|
||||
else:
|
||||
_init_file_db(db_url)
|
||||
|
||||
|
||||
def _init_memory_db(db_url):
|
||||
"""Initialize an in-memory SQLite database using metadata.create_all.
|
||||
|
||||
Alembic migrations don't work with in-memory SQLite because each
|
||||
connection gets its own separate database — tables created by Alembic's
|
||||
internal connection are lost immediately.
|
||||
"""
|
||||
engine = create_engine(
|
||||
db_url,
|
||||
poolclass=StaticPool,
|
||||
connect_args={"check_same_thread": False},
|
||||
)
|
||||
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
Base.metadata.create_all(engine)
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
def _init_file_db(db_url):
|
||||
"""Initialize a file-backed SQLite database using Alembic migrations."""
|
||||
db_path = get_db_path()
|
||||
db_exists = os.path.exists(db_path)
|
||||
|
||||
@ -75,6 +140,14 @@ def init_db():
|
||||
|
||||
# Check if we need to upgrade
|
||||
engine = create_engine(db_url)
|
||||
|
||||
# Enable foreign key enforcement for SQLite
|
||||
@event.listens_for(engine, "connect")
|
||||
def set_sqlite_pragma(dbapi_connection, connection_record):
|
||||
cursor = dbapi_connection.cursor()
|
||||
cursor.execute("PRAGMA foreign_keys=ON")
|
||||
cursor.close()
|
||||
|
||||
conn = engine.connect()
|
||||
|
||||
context = MigrationContext.configure(conn)
|
||||
@ -104,6 +177,12 @@ def init_db():
|
||||
logging.exception("Error upgrading database: ")
|
||||
raise e
|
||||
|
||||
# Acquire an OS-level file lock after migrations are complete.
|
||||
# Alembic uses its own connection, so we must wait until it's done
|
||||
# before locking — otherwise our own lock blocks the migration.
|
||||
conn.close()
|
||||
_acquire_file_lock(db_path)
|
||||
|
||||
global Session
|
||||
Session = sessionmaker(bind=engine)
|
||||
|
||||
|
||||
@ -1,9 +1,18 @@
|
||||
from typing import Any
|
||||
from datetime import datetime
|
||||
from sqlalchemy import MetaData
|
||||
from sqlalchemy.orm import DeclarativeBase
|
||||
|
||||
NAMING_CONVENTION = {
|
||||
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
|
||||
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
|
||||
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||
"pk": "pk_%(table_name)s",
|
||||
}
|
||||
|
||||
class Base(DeclarativeBase):
|
||||
pass
|
||||
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||
|
||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||
fields = obj.__table__.columns.keys()
|
||||
|
||||
@ -6,6 +6,7 @@ import uuid
|
||||
import glob
|
||||
import shutil
|
||||
import logging
|
||||
import tempfile
|
||||
from aiohttp import web
|
||||
from urllib import parse
|
||||
from comfy.cli_args import args
|
||||
@ -377,8 +378,15 @@ class UserManager():
|
||||
try:
|
||||
body = await request.read()
|
||||
|
||||
with open(path, "wb") as f:
|
||||
f.write(body)
|
||||
dir_name = os.path.dirname(path)
|
||||
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||
try:
|
||||
with os.fdopen(fd, "wb") as f:
|
||||
f.write(body)
|
||||
os.replace(tmp_path, path)
|
||||
except:
|
||||
os.unlink(tmp_path)
|
||||
raise
|
||||
except OSError as e:
|
||||
logging.warning(f"Error saving file '{path}': {e}")
|
||||
return web.Response(
|
||||
|
||||
@ -27,6 +27,7 @@ class AudioEncoderModel():
|
||||
self.model.eval()
|
||||
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
|
||||
self.model_sample_rate = 16000
|
||||
comfy.model_management.archive_model_dtypes(self.model)
|
||||
|
||||
def load_sd(self, sd):
|
||||
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())
|
||||
|
||||
@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
|
||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||
|
||||
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
|
||||
|
||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||
|
||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||
@ -147,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||
|
||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||
|
||||
@ -232,7 +235,7 @@ database_default_path = os.path.abspath(
|
||||
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
|
||||
)
|
||||
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
|
||||
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
|
||||
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
|
||||
|
||||
if comfy.options.args_parsing:
|
||||
args = parser.parse_args()
|
||||
@ -260,4 +263,6 @@ else:
|
||||
args.fast = set(args.fast)
|
||||
|
||||
def enables_dynamic_vram():
|
||||
if args.enable_dynamic_vram:
|
||||
return True
|
||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||
|
||||
@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
|
||||
"""COMBO type only. Specifies the configuration for a multi-select widget.
|
||||
Available after ComfyUI frontend v1.13.4
|
||||
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
|
||||
gradient_stops: NotRequired[list[list[float]]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``)."""
|
||||
gradient_stops: NotRequired[list[dict]]
|
||||
"""Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
|
||||
|
||||
|
||||
class HiddenInputTypeDict(TypedDict):
|
||||
|
||||
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
||||
output_block[i:i + slice_size].copy_(block)
|
||||
|
||||
return output_fp4, to_blocked(output_block, flatten=False)
|
||||
|
||||
|
||||
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
|
||||
def roundup(x_val, multiple):
|
||||
return ((x_val + multiple - 1) // multiple) * multiple
|
||||
|
||||
if pad_32x:
|
||||
rows, cols = x.shape
|
||||
padded_rows = roundup(rows, 32)
|
||||
padded_cols = roundup(cols, 32)
|
||||
if padded_rows != rows or padded_cols != cols:
|
||||
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||
|
||||
F8_E4M3_MAX = 448.0
|
||||
E8M0_BIAS = 127
|
||||
BLOCK_SIZE = 32
|
||||
|
||||
rows, cols = x.shape
|
||||
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
|
||||
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
|
||||
|
||||
# E8M0 block scales (power-of-2 exponents)
|
||||
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
|
||||
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
|
||||
block_scales_e8m0 = exp_biased.to(torch.uint8)
|
||||
|
||||
zero_mask = (max_abs == 0)
|
||||
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
|
||||
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
|
||||
|
||||
# Scale per-block then stochastic round
|
||||
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
|
||||
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
|
||||
|
||||
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
|
||||
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||
|
||||
@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
|
||||
return tensor * m_mult
|
||||
else:
|
||||
for d in modulation_dims:
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
|
||||
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
|
||||
if m_add is not None:
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
|
||||
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
|
||||
return tensor
|
||||
|
||||
|
||||
@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
|
||||
del txt_k, img_k
|
||||
v = torch.cat((txt_v, img_v), dim=2)
|
||||
del txt_v, img_v
|
||||
|
||||
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||
|
||||
# run actual attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
if "attn1_output_patch" in transformer_patches:
|
||||
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
|
||||
patch = transformer_patches["attn1_output_patch"]
|
||||
for p in patch:
|
||||
attn = p(attn, extra_options)
|
||||
@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
|
||||
del qkv
|
||||
q, k = self.norm(q, k, v)
|
||||
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
|
||||
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
|
||||
|
||||
# compute attention
|
||||
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
|
||||
del q, k, v
|
||||
|
||||
@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
|
||||
|
||||
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
|
||||
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
|
||||
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
|
||||
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
|
||||
|
||||
x_out = freqs_cis[..., 0] * x_[..., 0]
|
||||
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])
|
||||
|
||||
@ -44,6 +44,22 @@ class FluxParams:
|
||||
txt_norm: bool = False
|
||||
|
||||
|
||||
def invert_slices(slices, length):
|
||||
sorted_slices = sorted(slices)
|
||||
result = []
|
||||
current = 0
|
||||
|
||||
for start, end in sorted_slices:
|
||||
if current < start:
|
||||
result.append((current, start))
|
||||
current = max(current, end)
|
||||
|
||||
if current < length:
|
||||
result.append((current, length))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Flux(nn.Module):
|
||||
"""
|
||||
Transformer model for flow matching on sequences.
|
||||
@ -138,6 +154,7 @@ class Flux(nn.Module):
|
||||
y: Tensor,
|
||||
guidance: Tensor = None,
|
||||
control = None,
|
||||
timestep_zero_index=None,
|
||||
transformer_options={},
|
||||
attn_mask: Tensor = None,
|
||||
) -> Tensor:
|
||||
@ -164,13 +181,9 @@ class Flux(nn.Module):
|
||||
txt = self.txt_norm(txt)
|
||||
txt = self.txt_in(txt)
|
||||
|
||||
vec_orig = vec
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
|
||||
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
img = out["img"]
|
||||
txt = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
@ -182,6 +195,24 @@ class Flux(nn.Module):
|
||||
else:
|
||||
pe = None
|
||||
|
||||
vec_orig = vec
|
||||
txt_vec = vec
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
modulation_dims = []
|
||||
batch = vec.shape[0] // 2
|
||||
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
|
||||
invert = invert_slices(timestep_zero_index, img.shape[1])
|
||||
for s in invert:
|
||||
modulation_dims.append((s[0], s[1], 0))
|
||||
for s in timestep_zero_index:
|
||||
modulation_dims.append((s[0], s[1], 1))
|
||||
extra_kwargs["modulation_dims_img"] = modulation_dims
|
||||
txt_vec = vec[:batch]
|
||||
|
||||
if self.params.global_modulation:
|
||||
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
|
||||
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
transformer_options["total_blocks"] = len(self.double_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
@ -195,7 +226,8 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("double_block", i)]({"img": img,
|
||||
@ -213,7 +245,8 @@ class Flux(nn.Module):
|
||||
vec=vec,
|
||||
pe=pe,
|
||||
attn_mask=attn_mask,
|
||||
transformer_options=transformer_options)
|
||||
transformer_options=transformer_options,
|
||||
**extra_kwargs)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_i = control.get("input")
|
||||
@ -230,6 +263,12 @@ class Flux(nn.Module):
|
||||
if self.params.global_modulation:
|
||||
vec, _ = self.single_stream_modulation(vec_orig)
|
||||
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
lambda a: 0 if a == 0 else a + txt.shape[1]
|
||||
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
|
||||
extra_kwargs["modulation_dims"] = modulation_dims_combined
|
||||
|
||||
transformer_options["total_blocks"] = len(self.single_blocks)
|
||||
transformer_options["block_type"] = "single"
|
||||
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
|
||||
@ -242,7 +281,8 @@ class Flux(nn.Module):
|
||||
vec=args["vec"],
|
||||
pe=args["pe"],
|
||||
attn_mask=args.get("attn_mask"),
|
||||
transformer_options=args.get("transformer_options"))
|
||||
transformer_options=args.get("transformer_options"),
|
||||
**extra_kwargs)
|
||||
return out
|
||||
|
||||
out = blocks_replace[("single_block", i)]({"img": img,
|
||||
@ -253,7 +293,7 @@ class Flux(nn.Module):
|
||||
{"original_block": block_wrap})
|
||||
img = out["img"]
|
||||
else:
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options)
|
||||
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
|
||||
|
||||
if control is not None: # Controlnet
|
||||
control_o = control.get("output")
|
||||
@ -264,7 +304,11 @@ class Flux(nn.Module):
|
||||
|
||||
img = img[:, txt.shape[1] :, ...]
|
||||
|
||||
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels)
|
||||
extra_kwargs = {}
|
||||
if timestep_zero_index is not None:
|
||||
extra_kwargs["modulation_dims"] = modulation_dims
|
||||
|
||||
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
|
||||
return img
|
||||
|
||||
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
|
||||
@ -312,13 +356,16 @@ class Flux(nn.Module):
|
||||
w_len = ((w_orig + (patch_size // 2)) // patch_size)
|
||||
img, img_ids = self.process_img(x, transformer_options=transformer_options)
|
||||
img_tokens = img.shape[1]
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
|
||||
timestep_zero = ref_latents_method == "index_timestep_zero"
|
||||
for ref in ref_latents:
|
||||
if ref_latents_method == "index":
|
||||
if ref_latents_method in ("index", "index_timestep_zero"):
|
||||
index += self.params.ref_index_scale
|
||||
h_offset = 0
|
||||
w_offset = 0
|
||||
@ -342,6 +389,13 @@ class Flux(nn.Module):
|
||||
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
img = torch.cat([img, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
|
||||
|
||||
@ -349,6 +403,6 @@ class Flux(nn.Module):
|
||||
for i in self.params.txt_ids_dims:
|
||||
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
|
||||
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
|
||||
out = out[:, :img_tokens]
|
||||
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]
|
||||
|
||||
@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
|
||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||
v,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
out = self.out_proj(x)
|
||||
@ -412,6 +413,7 @@ class Attention(nn.Module):
|
||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||
value,
|
||||
heads=self.num_heads,
|
||||
low_precision_attention=False,
|
||||
)
|
||||
|
||||
x = self.out_proj(x)
|
||||
|
||||
@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
|
||||
from .pixel_norm import PixelNorm
|
||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||
import comfy.ops
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||
|
||||
ops = comfy.ops.disable_weight_init
|
||||
@ -536,7 +537,7 @@ class Decoder(nn.Module):
|
||||
mark_conv3d_ended(self.conv_out)
|
||||
sample = self.conv_out(sample, causal=self.causal)
|
||||
if sample is not None and sample.shape[2] > 0:
|
||||
output.append(sample)
|
||||
output.append(sample.to(comfy.model_management.intermediate_device()))
|
||||
return
|
||||
|
||||
up_block = self.up_blocks[idx]
|
||||
|
||||
@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
|
||||
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
if first_op_done == False:
|
||||
model_management.soft_empty_cache(True)
|
||||
if cleared_cache == False:
|
||||
|
||||
@ -258,7 +258,8 @@ def slice_attention(q, k, v):
|
||||
r1[:, :, i:end] = torch.bmm(v, s2)
|
||||
del s2
|
||||
break
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
model_management.soft_empty_cache(True)
|
||||
steps *= 2
|
||||
if steps > 128:
|
||||
@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
|
||||
try:
|
||||
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
|
||||
out = out.transpose(2, 3).reshape(orig_shape)
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
|
||||
oom_fallback = True
|
||||
if oom_fallback:
|
||||
|
||||
@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
|
||||
try:
|
||||
attn_probs = attn_scores.softmax(dim=-1)
|
||||
del attn_scores
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
|
||||
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
|
||||
torch.exp(attn_scores, out=attn_scores)
|
||||
|
||||
@ -149,6 +149,9 @@ class Attention(nn.Module):
|
||||
seq_img = hidden_states.shape[1]
|
||||
seq_txt = encoder_hidden_states.shape[1]
|
||||
|
||||
transformer_patches = transformer_options.get("patches", {})
|
||||
extra_options = transformer_options.copy()
|
||||
|
||||
# Project and reshape to BHND format (batch, heads, seq, dim)
|
||||
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
|
||||
@ -167,15 +170,22 @@ class Attention(nn.Module):
|
||||
joint_key = torch.cat([txt_key, img_key], dim=2)
|
||||
joint_value = torch.cat([txt_value, img_value], dim=2)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
if encoder_hidden_states_mask is not None:
|
||||
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
|
||||
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
|
||||
else:
|
||||
attn_mask = None
|
||||
|
||||
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
|
||||
if "attn1_patch" in transformer_patches:
|
||||
patch = transformer_patches["attn1_patch"]
|
||||
for p in patch:
|
||||
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
|
||||
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
|
||||
|
||||
joint_query = apply_rope1(joint_query, image_rotary_emb)
|
||||
joint_key = apply_rope1(joint_key, image_rotary_emb)
|
||||
|
||||
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
|
||||
attn_mask, transformer_options=transformer_options,
|
||||
skip_reshape=True)
|
||||
@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
|
||||
timestep_zero_index = None
|
||||
if ref_latents is not None:
|
||||
ref_num_tokens = []
|
||||
h = 0
|
||||
w = 0
|
||||
index = 0
|
||||
@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
|
||||
hidden_states = torch.cat([hidden_states, kontext], dim=1)
|
||||
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
|
||||
ref_num_tokens.append(kontext.shape[1])
|
||||
if timestep_zero:
|
||||
if index > 0:
|
||||
timestep = torch.cat([timestep, timestep * 0], dim=0)
|
||||
timestep_zero_index = num_embeds
|
||||
transformer_options = transformer_options.copy()
|
||||
transformer_options["reference_image_num_tokens"] = ref_num_tokens
|
||||
|
||||
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
|
||||
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
hidden_states = self.img_in(hidden_states)
|
||||
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
|
||||
@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
|
||||
patches = transformer_options.get("patches", {})
|
||||
blocks_replace = patches_replace.get("dit", {})
|
||||
|
||||
if "post_input" in patches:
|
||||
for p in patches["post_input"]:
|
||||
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
|
||||
hidden_states = out["img"]
|
||||
encoder_hidden_states = out["txt"]
|
||||
img_ids = out["img_ids"]
|
||||
txt_ids = out["txt_ids"]
|
||||
|
||||
ids = torch.cat((txt_ids, img_ids), dim=1)
|
||||
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
|
||||
del ids, txt_ids, img_ids
|
||||
|
||||
transformer_options["total_blocks"] = len(self.transformer_blocks)
|
||||
transformer_options["block_type"] = "double"
|
||||
for i, block in enumerate(self.transformer_blocks):
|
||||
|
||||
@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
|
||||
for k in sdk:
|
||||
if k.endswith(".weight"):
|
||||
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
|
||||
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
|
||||
if tp > 0 and not k.startswith("clip_"):
|
||||
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
|
||||
|
||||
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
|
||||
clip_l_present = False
|
||||
|
||||
@ -1,9 +1,68 @@
|
||||
import math
|
||||
import ctypes
|
||||
import threading
|
||||
import dataclasses
|
||||
import torch
|
||||
from typing import NamedTuple
|
||||
|
||||
from comfy.quant_ops import QuantizedTensor
|
||||
|
||||
|
||||
class TensorFileSlice(NamedTuple):
|
||||
file_ref: object
|
||||
thread_id: int
|
||||
offset: int
|
||||
size: int
|
||||
|
||||
|
||||
def read_tensor_file_slice_into(tensor, destination):
|
||||
|
||||
if isinstance(tensor, QuantizedTensor):
|
||||
if not isinstance(destination, QuantizedTensor):
|
||||
return False
|
||||
if tensor._layout_cls != destination._layout_cls:
|
||||
return False
|
||||
|
||||
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||
return False
|
||||
|
||||
dst_orig_dtype = destination._params.orig_dtype
|
||||
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||
return True
|
||||
|
||||
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||
if info is None:
|
||||
return False
|
||||
|
||||
file_obj = info.file_ref
|
||||
if (destination.device.type != "cpu"
|
||||
or file_obj is None
|
||||
or threading.get_ident() != info.thread_id
|
||||
or destination.numel() * destination.element_size() < info.size):
|
||||
return False
|
||||
|
||||
if info.size == 0:
|
||||
return True
|
||||
|
||||
buf_type = ctypes.c_ubyte * info.size
|
||||
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||
|
||||
try:
|
||||
file_obj.seek(info.offset)
|
||||
done = 0
|
||||
while done < info.size:
|
||||
try:
|
||||
n = file_obj.readinto(view[done:])
|
||||
except OSError:
|
||||
return False
|
||||
if n <= 0:
|
||||
return False
|
||||
done += n
|
||||
return True
|
||||
finally:
|
||||
view.release()
|
||||
|
||||
class TensorGeometry(NamedTuple):
|
||||
shape: any
|
||||
dtype: torch.dtype
|
||||
|
||||
@ -1,4 +1,5 @@
|
||||
import json
|
||||
import comfy.memory_management
|
||||
import comfy.supported_models
|
||||
import comfy.supported_models_base
|
||||
import comfy.utils
|
||||
@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
|
||||
new[:old_weight.shape[0]] = old_weight
|
||||
old_weight = new
|
||||
|
||||
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
|
||||
old_weight = old_weight.clone()
|
||||
|
||||
w = old_weight.narrow(offset[0], offset[1], offset[2])
|
||||
else:
|
||||
if comfy.memory_management.aimdo_enabled:
|
||||
weight = weight.clone()
|
||||
old_weight = weight
|
||||
w = weight
|
||||
w[:] = fun(weight)
|
||||
|
||||
@ -270,6 +270,23 @@ try:
|
||||
except:
|
||||
OOM_EXCEPTION = Exception
|
||||
|
||||
try:
|
||||
ACCELERATOR_ERROR = torch.AcceleratorError
|
||||
except AttributeError:
|
||||
ACCELERATOR_ERROR = RuntimeError
|
||||
|
||||
def is_oom(e):
|
||||
if isinstance(e, OOM_EXCEPTION):
|
||||
return True
|
||||
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
|
||||
discard_cuda_async_error()
|
||||
return True
|
||||
return False
|
||||
|
||||
def raise_non_oom(e):
|
||||
if not is_oom(e):
|
||||
raise e
|
||||
|
||||
XFORMERS_VERSION = ""
|
||||
XFORMERS_ENABLED_VAE = True
|
||||
if args.disable_xformers:
|
||||
@ -383,7 +400,7 @@ try:
|
||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||
ENABLE_PYTORCH_ATTENTION = True
|
||||
if rocm_version >= (7, 0):
|
||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||
@ -488,6 +505,28 @@ def module_size(module):
|
||||
module_mem += t.nbytes
|
||||
return module_mem
|
||||
|
||||
def module_mmap_residency(module, free=False):
|
||||
mmap_touched_mem = 0
|
||||
module_mem = 0
|
||||
bounced_mmaps = set()
|
||||
sd = module.state_dict()
|
||||
for k in sd:
|
||||
t = sd[k]
|
||||
module_mem += t.nbytes
|
||||
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||
continue
|
||||
mmap_touched_mem += t.nbytes
|
||||
if not free:
|
||||
continue
|
||||
storage._comfy_tensor_mmap_touched = False
|
||||
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||
if mmap_obj in bounced_mmaps:
|
||||
continue
|
||||
mmap_obj.bounce()
|
||||
bounced_mmaps.add(mmap_obj)
|
||||
return mmap_touched_mem, module_mem
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model):
|
||||
self._set_model(model)
|
||||
@ -502,6 +541,7 @@ class LoadedModel:
|
||||
if model.parent is not None:
|
||||
self._parent_model = weakref.ref(model.parent)
|
||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||
self._patcher_finalizer.atexit = False
|
||||
|
||||
def _switch_parent(self):
|
||||
model = self._parent_model()
|
||||
@ -515,6 +555,9 @@ class LoadedModel:
|
||||
def model_memory(self):
|
||||
return self.model.model_size()
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return self.model.model_mmap_residency(free=free)
|
||||
|
||||
def model_loaded_memory(self):
|
||||
return self.model.loaded_size()
|
||||
|
||||
@ -545,6 +588,7 @@ class LoadedModel:
|
||||
|
||||
self.real_model = weakref.ref(real_model)
|
||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||
self.model_finalizer.atexit = False
|
||||
return real_model
|
||||
|
||||
def should_reload_model(self, force_patch_weights=False):
|
||||
@ -616,7 +660,7 @@ def extra_reserved_memory():
|
||||
def minimum_inference_memory():
|
||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||
cleanup_models_gc()
|
||||
unloaded_model = []
|
||||
can_unload = []
|
||||
@ -629,13 +673,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||
shift_model.currently_used = False
|
||||
|
||||
for x in sorted(can_unload):
|
||||
can_unload_sorted = sorted(can_unload)
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
memory_to_free = 1e32
|
||||
ram_to_free = 1e32
|
||||
pins_to_free = 1e32
|
||||
if not DISABLE_SMART_MEMORY:
|
||||
memory_to_free = memory_required - get_free_memory(device)
|
||||
ram_to_free = ram_required - get_free_ram()
|
||||
pins_to_free = pins_required - get_free_ram()
|
||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||
#don't actually unload dynamic models for the sake of other dynamic models
|
||||
#as that works on-demand.
|
||||
@ -644,9 +689,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
unloaded_model.append(i)
|
||||
if ram_to_free > 0:
|
||||
if pins_to_free > 0:
|
||||
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||
|
||||
for x in can_unload_sorted:
|
||||
i = x[-1]
|
||||
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||
if ram_to_free <= 0 and i not in unloaded_model:
|
||||
continue
|
||||
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||
if resident_memory > 0:
|
||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
||||
|
||||
for i in sorted(unloaded_model, reverse=True):
|
||||
unloaded_models.append(current_loaded_models.pop(i))
|
||||
@ -712,17 +766,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
||||
|
||||
|
||||
total_memory_required = {}
|
||||
total_pins_required = {}
|
||||
total_ram_required = {}
|
||||
for loaded_model in models_to_load:
|
||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
||||
#want to do.
|
||||
#FIXME: This should subtract off the to_load current pin consumption.
|
||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
||||
device = loaded_model.device
|
||||
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||
#make this JIT to keep as much pinned as possible.
|
||||
pins_required = model_memory - pinned_memory
|
||||
ram_required = model_memory - resident_memory
|
||||
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
||||
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||
device,
|
||||
for_dynamic=free_for_dynamic,
|
||||
pins_required=total_pins_required[device],
|
||||
ram_required=total_ram_required[device])
|
||||
|
||||
for device in total_memory_required:
|
||||
if device != torch.device("cpu"):
|
||||
@ -939,7 +1003,7 @@ def text_encoder_offload_device():
|
||||
def text_encoder_device():
|
||||
if args.gpu_only:
|
||||
return get_torch_device()
|
||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED):
|
||||
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
|
||||
if should_use_fp16(prioritize_performance=False):
|
||||
return get_torch_device()
|
||||
else:
|
||||
@ -988,6 +1052,12 @@ def intermediate_device():
|
||||
else:
|
||||
return torch.device("cpu")
|
||||
|
||||
def intermediate_dtype():
|
||||
if args.fp16_intermediates:
|
||||
return torch.float16
|
||||
else:
|
||||
return torch.float32
|
||||
|
||||
def vae_device():
|
||||
if args.cpu_vae:
|
||||
return torch.device("cpu")
|
||||
@ -1148,6 +1218,7 @@ def reset_cast_buffers():
|
||||
LARGEST_CASTED_WEIGHT = (None, 0)
|
||||
for offload_stream in STREAM_CAST_BUFFERS:
|
||||
offload_stream.synchronize()
|
||||
synchronize()
|
||||
STREAM_CAST_BUFFERS.clear()
|
||||
soft_empty_cache()
|
||||
|
||||
@ -1207,6 +1278,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
||||
dest_view = dest_views.pop(0)
|
||||
if tensor is None:
|
||||
continue
|
||||
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||
continue
|
||||
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||
storage._comfy_tensor_mmap_touched = True
|
||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||
|
||||
|
||||
@ -1262,7 +1338,7 @@ def discard_cuda_async_error():
|
||||
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
|
||||
_ = a + b
|
||||
synchronize()
|
||||
except torch.AcceleratorError:
|
||||
except RuntimeError:
|
||||
#Dump it! We already know about it from the synchronous return
|
||||
pass
|
||||
|
||||
@ -1644,6 +1720,19 @@ def supports_nvfp4_compute(device=None):
|
||||
|
||||
return True
|
||||
|
||||
def supports_mxfp8_compute(device=None):
|
||||
if not is_nvidia():
|
||||
return False
|
||||
|
||||
if torch_version_numeric < (2, 10):
|
||||
return False
|
||||
|
||||
props = torch.cuda.get_device_properties(device)
|
||||
if props.major < 10:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def extended_fp16_support():
|
||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||
if torch_version_numeric < (2, 7):
|
||||
|
||||
@ -297,6 +297,9 @@ class ModelPatcher:
|
||||
self.size = comfy.model_management.module_size(self.model)
|
||||
return self.size
|
||||
|
||||
def model_mmap_residency(self, free=False):
|
||||
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||
|
||||
def get_ram_usage(self):
|
||||
return self.model_size()
|
||||
|
||||
@ -599,6 +602,27 @@ class ModelPatcher:
|
||||
|
||||
return models
|
||||
|
||||
def model_patches_call_function(self, function_name="cleanup", arguments={}):
|
||||
to = self.model_options["transformer_options"]
|
||||
if "patches" in to:
|
||||
patches = to["patches"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for i in range(len(patch_list)):
|
||||
if hasattr(patch_list[i], function_name):
|
||||
getattr(patch_list[i], function_name)(**arguments)
|
||||
if "patches_replace" in to:
|
||||
patches = to["patches_replace"]
|
||||
for name in patches:
|
||||
patch_list = patches[name]
|
||||
for k in patch_list:
|
||||
if hasattr(patch_list[k], function_name):
|
||||
getattr(patch_list[k], function_name)(**arguments)
|
||||
if "model_function_wrapper" in self.model_options:
|
||||
wrap_func = self.model_options["model_function_wrapper"]
|
||||
if hasattr(wrap_func, function_name):
|
||||
getattr(wrap_func, function_name)(**arguments)
|
||||
|
||||
def model_dtype(self):
|
||||
if hasattr(self.model, "get_dtype"):
|
||||
return self.model.get_dtype()
|
||||
@ -715,8 +739,8 @@ class ModelPatcher:
|
||||
default = True # default random weights in non leaf modules
|
||||
break
|
||||
if default and default_device is not None:
|
||||
for param in params.values():
|
||||
param.data = param.data.to(device=default_device)
|
||||
for param_name, param in params.items():
|
||||
param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
|
||||
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
|
||||
module_mem = comfy.model_management.module_size(m)
|
||||
module_offload_mem = module_mem
|
||||
@ -1042,6 +1066,10 @@ class ModelPatcher:
|
||||
|
||||
return self.model.model_loaded_weight_memory - current_used
|
||||
|
||||
def pinned_memory_size(self):
|
||||
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||
return 0
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
pass
|
||||
|
||||
@ -1062,6 +1090,7 @@ class ModelPatcher:
|
||||
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
|
||||
|
||||
def cleanup(self):
|
||||
self.model_patches_call_function(function_name="cleanup")
|
||||
self.clean_hooks()
|
||||
if hasattr(self.model, "current_patcher"):
|
||||
self.model.current_patcher = None
|
||||
@ -1631,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
||||
|
||||
return freed
|
||||
|
||||
def pinned_memory_size(self):
|
||||
total = 0
|
||||
loading = self._load_list(for_dynamic=True)
|
||||
for x in loading:
|
||||
_, _, _, _, m, _ = x
|
||||
pin = comfy.pinned_memory.get_pin(m)
|
||||
if pin is not None:
|
||||
total += pin.numel() * pin.element_size()
|
||||
return total
|
||||
|
||||
def partially_unload_ram(self, ram_to_unload):
|
||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||
for x in loading:
|
||||
|
||||
236
comfy/ops.py
236
comfy/ops.py
@ -306,10 +306,40 @@ class CastWeightBiasOp:
|
||||
bias_function = []
|
||||
|
||||
class disable_weight_init:
|
||||
@staticmethod
|
||||
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||
missing_keys, unexpected_keys, weight_shape,
|
||||
bias_shape=None):
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k, v in state_dict.items():
|
||||
key = k[prefix_len:]
|
||||
if key == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif bias_shape is not None and key == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
if module.weight is None:
|
||||
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "weight")
|
||||
|
||||
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
||||
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
||||
missing_keys.append(prefix + "bias")
|
||||
|
||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||
|
||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
# don't trust subclasses that BYO state dict loader to call us.
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||
super().__init__(in_features, out_features, bias, device, dtype)
|
||||
return
|
||||
|
||||
@ -330,32 +360,21 @@ class disable_weight_init:
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||
prefix_len = len(prefix)
|
||||
for k,v in state_dict.items():
|
||||
if k[prefix_len:] == "weight":
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
elif k[prefix_len:] == "bias" and v is not None:
|
||||
if not assign_to_params_buffers:
|
||||
v = v.clone()
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
else:
|
||||
unexpected_keys.append(k)
|
||||
|
||||
#Reconcile default construction of the weight if its missing.
|
||||
if self.weight is None:
|
||||
v = torch.zeros(self.in_features, self.out_features)
|
||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"weight")
|
||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
||||
v = torch.zeros(self.out_features,)
|
||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||
missing_keys.append(prefix+"bias")
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.in_features, self.out_features),
|
||||
bias_shape=(self.out_features,),
|
||||
)
|
||||
|
||||
|
||||
def reset_parameters(self):
|
||||
@ -547,6 +566,53 @@ class disable_weight_init:
|
||||
return super().forward(*args, **kwargs)
|
||||
|
||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
|
||||
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
|
||||
_freeze=False, device=None, dtype=None):
|
||||
# don't trust subclasses that BYO state dict loader to call us.
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
|
||||
norm_type, scale_grad_by_freq, sparse, _weight,
|
||||
_freeze, device, dtype)
|
||||
return
|
||||
|
||||
torch.nn.Module.__init__(self)
|
||||
self.num_embeddings = num_embeddings
|
||||
self.embedding_dim = embedding_dim
|
||||
self.padding_idx = padding_idx
|
||||
self.max_norm = max_norm
|
||||
self.norm_type = norm_type
|
||||
self.scale_grad_by_freq = scale_grad_by_freq
|
||||
self.sparse = sparse
|
||||
# Keep shape/dtype visible for module introspection without reserving storage.
|
||||
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
|
||||
self.weight = torch.nn.Parameter(
|
||||
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
|
||||
requires_grad=False,
|
||||
)
|
||||
self.bias = None
|
||||
self.weight_comfy_model_dtype = dtype
|
||||
|
||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||
strict, missing_keys, unexpected_keys, error_msgs):
|
||||
|
||||
if (not comfy.model_management.WINDOWS
|
||||
or not comfy.memory_management.aimdo_enabled
|
||||
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||
missing_keys, unexpected_keys, error_msgs)
|
||||
disable_weight_init._lazy_load_from_state_dict(
|
||||
self,
|
||||
state_dict,
|
||||
prefix,
|
||||
local_metadata,
|
||||
missing_keys,
|
||||
unexpected_keys,
|
||||
weight_shape=(self.num_embeddings, self.embedding_dim),
|
||||
)
|
||||
|
||||
def reset_parameters(self):
|
||||
self.bias = None
|
||||
return None
|
||||
@ -710,6 +776,71 @@ from .quant_ops import (
|
||||
)
|
||||
|
||||
|
||||
class QuantLinearFunc(torch.autograd.Function):
|
||||
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||
input_shape = input_float.shape
|
||||
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||
|
||||
# Quantize input (same as inference path)
|
||||
if layout_type is not None:
|
||||
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||
else:
|
||||
q_input = inp
|
||||
|
||||
w = weight.detach() if weight.requires_grad else weight
|
||||
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||
|
||||
output = torch.nn.functional.linear(q_input, w, b)
|
||||
|
||||
# Restore original input shape
|
||||
if len(input_shape) > 2:
|
||||
output = output.unflatten(0, input_shape[:-1])
|
||||
|
||||
ctx.save_for_backward(input_float, weight)
|
||||
ctx.input_shape = input_shape
|
||||
ctx.has_bias = bias is not None
|
||||
ctx.compute_dtype = compute_dtype
|
||||
ctx.weight_requires_grad = weight.requires_grad
|
||||
|
||||
return output
|
||||
|
||||
@staticmethod
|
||||
@torch.autograd.function.once_differentiable
|
||||
def backward(ctx, grad_output):
|
||||
input_float, weight = ctx.saved_tensors
|
||||
compute_dtype = ctx.compute_dtype
|
||||
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||
|
||||
# Dequantize weight to compute dtype for backward matmul
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight_f = weight.dequantize().to(compute_dtype)
|
||||
else:
|
||||
weight_f = weight.to(compute_dtype)
|
||||
|
||||
# grad_input = grad_output @ weight
|
||||
grad_input = torch.mm(grad_2d, weight_f)
|
||||
if len(ctx.input_shape) > 2:
|
||||
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||
|
||||
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||
grad_weight = None
|
||||
if ctx.weight_requires_grad:
|
||||
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||
|
||||
# grad_bias
|
||||
grad_bias = None
|
||||
if ctx.has_bias:
|
||||
grad_bias = grad_2d.sum(dim=0)
|
||||
|
||||
return grad_input, grad_weight, grad_bias, None, None, None
|
||||
|
||||
|
||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||
class MixedPrecisionOps(manual_cast):
|
||||
_quant_config = quant_config
|
||||
@ -801,6 +932,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "mxfp8":
|
||||
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||
dtype=torch.uint8)
|
||||
|
||||
if block_scale is None:
|
||||
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||
|
||||
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||
|
||||
params = layout_cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||
orig_shape=(self.out_features, self.in_features),
|
||||
)
|
||||
|
||||
elif self.quant_format == "nvfp4":
|
||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||
@ -888,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
#If cast needs to apply lora, it should be done in the compute dtype
|
||||
compute_dtype = input.dtype
|
||||
|
||||
if (getattr(self, 'layout_type', None) is not None and
|
||||
_use_quantized = (
|
||||
getattr(self, 'layout_type', None) is not None and
|
||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
||||
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||
)
|
||||
|
||||
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||
if (input.requires_grad and _use_quantized):
|
||||
|
||||
weight, bias, offload_stream = cast_bias_weight(
|
||||
self,
|
||||
input,
|
||||
offloadable=True,
|
||||
compute_dtype=compute_dtype,
|
||||
want_requant=True
|
||||
)
|
||||
|
||||
scale = getattr(self, 'input_scale', None)
|
||||
if scale is not None:
|
||||
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||
|
||||
output = QuantLinearFunc.apply(
|
||||
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||
)
|
||||
|
||||
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||
return output
|
||||
|
||||
# Inference path (unchanged)
|
||||
if _use_quantized:
|
||||
|
||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||
@ -939,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
for key, param in self._parameters.items():
|
||||
if param is None:
|
||||
continue
|
||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
||||
p = fn(param)
|
||||
if p.is_inference():
|
||||
p = p.clone()
|
||||
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||
for key, buf in self._buffers.items():
|
||||
if buf is not None:
|
||||
self._buffers[key] = fn(buf)
|
||||
@ -950,12 +1127,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||
|
||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||
logging.info("Using mixed precision operations")
|
||||
disabled = set()
|
||||
if not nvfp4_compute:
|
||||
disabled.add("nvfp4")
|
||||
if not mxfp8_compute:
|
||||
disabled.add("mxfp8")
|
||||
if not fp8_compute:
|
||||
disabled.add("float8_e4m3fn")
|
||||
disabled.add("float8_e5m2")
|
||||
|
||||
@ -1,6 +1,7 @@
|
||||
import torch
|
||||
import comfy.model_management
|
||||
import comfy.memory_management
|
||||
import comfy_aimdo.host_buffer
|
||||
import comfy_aimdo.torch
|
||||
|
||||
from comfy.cli_args import args
|
||||
|
||||
@ -12,18 +13,31 @@ def pin_memory(module):
|
||||
return
|
||||
#FIXME: This is a RAM cache trigger event
|
||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||
pin = torch.empty((size,), dtype=torch.uint8)
|
||||
if comfy.model_management.pin_memory(pin):
|
||||
module._pin = pin
|
||||
else:
|
||||
|
||||
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
try:
|
||||
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||
except RuntimeError:
|
||||
module.pin_failed = True
|
||||
return False
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||
module._pin_hostbuf = hostbuf
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||
return True
|
||||
|
||||
def unpin_memory(module):
|
||||
if get_pin(module) is None:
|
||||
return 0
|
||||
size = module._pin.numel() * module._pin.element_size()
|
||||
comfy.model_management.unpin_memory(module._pin)
|
||||
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||
|
||||
del module._pin
|
||||
del module._pin_hostbuf
|
||||
return size
|
||||
|
||||
@ -43,6 +43,18 @@ except ImportError as e:
|
||||
def get_layout_class(name):
|
||||
return None
|
||||
|
||||
_CK_MXFP8_AVAILABLE = False
|
||||
if _CK_AVAILABLE:
|
||||
try:
|
||||
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
|
||||
_CK_MXFP8_AVAILABLE = True
|
||||
except ImportError:
|
||||
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
|
||||
|
||||
if not _CK_MXFP8_AVAILABLE:
|
||||
class _CKMxfp8Layout:
|
||||
pass
|
||||
|
||||
import comfy.float
|
||||
|
||||
# ==============================================================================
|
||||
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
if tensor.dim() != 2:
|
||||
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||
|
||||
orig_dtype = tensor.dtype
|
||||
orig_shape = tuple(tensor.shape)
|
||||
|
||||
padded_shape = cls.get_padded_shape(orig_shape)
|
||||
needs_padding = padded_shape != orig_shape
|
||||
|
||||
if stochastic_rounding > 0:
|
||||
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||
else:
|
||||
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||
|
||||
params = cls.Params(
|
||||
scale=block_scale,
|
||||
orig_dtype=orig_dtype,
|
||||
orig_shape=orig_shape,
|
||||
)
|
||||
return qdata, params
|
||||
|
||||
|
||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||
@classmethod
|
||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||
|
||||
QUANT_ALGOS = {
|
||||
"float8_e4m3fn": {
|
||||
@ -157,6 +196,14 @@ QUANT_ALGOS = {
|
||||
},
|
||||
}
|
||||
|
||||
if _CK_MXFP8_AVAILABLE:
|
||||
QUANT_ALGOS["mxfp8"] = {
|
||||
"storage_t": torch.float8_e4m3fn,
|
||||
"parameters": {"weight_scale", "input_scale"},
|
||||
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||
"group_size": 32,
|
||||
}
|
||||
|
||||
|
||||
# ==============================================================================
|
||||
# Re-exports for backward compatibility
|
||||
|
||||
33
comfy/sd.py
33
comfy/sd.py
@ -871,13 +871,16 @@ class VAE:
|
||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||
return pixels
|
||||
|
||||
def vae_output_dtype(self):
|
||||
return model_management.intermediate_dtype()
|
||||
|
||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
output = self.process_output(
|
||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||
@ -887,16 +890,16 @@ class VAE:
|
||||
|
||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||
if samples.ndim == 3:
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
else:
|
||||
og_shape = samples.shape
|
||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||
|
||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||
|
||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||
@ -905,7 +908,7 @@ class VAE:
|
||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||
@ -914,7 +917,7 @@ class VAE:
|
||||
|
||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||
if self.latent_dim == 1:
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
out_channels = self.latent_channels
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
else:
|
||||
@ -923,7 +926,7 @@ class VAE:
|
||||
tile_x = tile_x // extra_channel_size
|
||||
overlap = overlap // extra_channel_size
|
||||
upscale_amount = 1 / self.downscale_ratio
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||
|
||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||
if self.latent_dim == 1:
|
||||
@ -932,7 +935,7 @@ class VAE:
|
||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||
|
||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||
|
||||
def decode(self, samples_in, vae_options={}):
|
||||
@ -950,11 +953,12 @@ class VAE:
|
||||
|
||||
for x in range(0, samples_in.shape[0], batch_number):
|
||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
|
||||
if pixel_samples is None:
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
pixel_samples[x:x+batch_number] = out
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
@ -1024,12 +1028,13 @@ class VAE:
|
||||
samples = None
|
||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||
if samples is None:
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||
samples[x:x + batch_number] = out
|
||||
|
||||
except model_management.OOM_EXCEPTION:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
|
||||
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
|
||||
#exception and the exception itself refs them all until we get out of this except block.
|
||||
|
||||
@ -20,6 +20,8 @@
|
||||
import torch
|
||||
import math
|
||||
import struct
|
||||
import ctypes
|
||||
import os
|
||||
import comfy.memory_management
|
||||
import safetensors.torch
|
||||
import numpy as np
|
||||
@ -32,7 +34,7 @@ from einops import rearrange
|
||||
from comfy.cli_args import args
|
||||
import json
|
||||
import time
|
||||
import mmap
|
||||
import threading
|
||||
import warnings
|
||||
|
||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||
@ -81,14 +83,17 @@ _TYPES = {
|
||||
}
|
||||
|
||||
def load_safetensors(ckpt):
|
||||
f = open(ckpt, "rb")
|
||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
||||
mv = memoryview(mapping)
|
||||
import comfy_aimdo.model_mmap
|
||||
|
||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
||||
f = open(ckpt, "rb", buffering=0)
|
||||
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||
file_size = os.path.getsize(ckpt)
|
||||
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||
|
||||
mv = mv[8 + header_size:]
|
||||
header_size = struct.unpack("<Q", mv[:8])[0]
|
||||
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
|
||||
|
||||
mv = mv[(data_base_offset := 8 + header_size):]
|
||||
|
||||
sd = {}
|
||||
for name, info in header.items():
|
||||
@ -102,7 +107,14 @@ def load_safetensors(ckpt):
|
||||
with warnings.catch_warnings():
|
||||
#We are working with read-only RAM by design
|
||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||
storage = tensor.untyped_storage()
|
||||
setattr(storage,
|
||||
"_comfy_tensor_file_slice",
|
||||
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||
sd[name] = tensor
|
||||
|
||||
return sd, header.get("__metadata__", {}),
|
||||
|
||||
@ -885,6 +897,10 @@ def set_attr(obj, attr, value):
|
||||
return prev
|
||||
|
||||
def set_attr_param(obj, attr, value):
|
||||
# Clone inference tensors (created under torch.inference_mode) since
|
||||
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||
value = value.clone()
|
||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||
|
||||
def set_attr_buffer(obj, attr, value):
|
||||
|
||||
@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
|
||||
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
|
||||
"extension": {"manager": {"supports_v4": True}},
|
||||
"node_replacements": True,
|
||||
"assets": args.enable_assets,
|
||||
}
|
||||
|
||||
|
||||
|
||||
@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
super().__init__()
|
||||
self.node_replacement = self.NodeReplacement()
|
||||
self.execution = self.Execution()
|
||||
self.caching = self.Caching()
|
||||
|
||||
class NodeReplacement(ProxiedSingleton):
|
||||
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||
@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
|
||||
image=to_display,
|
||||
)
|
||||
|
||||
class Caching(ProxiedSingleton):
|
||||
"""
|
||||
External cache provider API for sharing cached node outputs
|
||||
across ComfyUI instances.
|
||||
|
||||
Example::
|
||||
|
||||
from comfy_api.latest import Caching
|
||||
|
||||
class MyCacheProvider(Caching.CacheProvider):
|
||||
async def on_lookup(self, context):
|
||||
... # check external storage
|
||||
|
||||
async def on_store(self, context, value):
|
||||
... # store to external storage
|
||||
|
||||
Caching.register_provider(MyCacheProvider())
|
||||
"""
|
||||
from ._caching import CacheProvider, CacheContext, CacheValue
|
||||
|
||||
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
from comfy_execution.cache_provider import register_cache_provider
|
||||
register_cache_provider(provider)
|
||||
|
||||
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||
"""Unregister a previously registered cache provider."""
|
||||
from comfy_execution.cache_provider import unregister_cache_provider
|
||||
unregister_cache_provider(provider)
|
||||
|
||||
class ComfyExtension(ABC):
|
||||
async def on_load(self) -> None:
|
||||
"""
|
||||
@ -116,6 +147,9 @@ class Types:
|
||||
VOXEL = VOXEL
|
||||
File3D = File3D
|
||||
|
||||
|
||||
Caching = ComfyAPI_latest.Caching
|
||||
|
||||
ComfyAPI = ComfyAPI_latest
|
||||
|
||||
# Create a synchronous version of the API
|
||||
@ -135,6 +169,7 @@ __all__ = [
|
||||
"Input",
|
||||
"InputImpl",
|
||||
"Types",
|
||||
"Caching",
|
||||
"ComfyExtension",
|
||||
"io",
|
||||
"IO",
|
||||
|
||||
42
comfy_api/latest/_caching.py
Normal file
42
comfy_api/latest/_caching.py
Normal file
@ -0,0 +1,42 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheContext:
|
||||
node_id: str
|
||||
class_type: str
|
||||
cache_key_hash: str # SHA256 hex digest
|
||||
|
||||
|
||||
@dataclass
|
||||
class CacheValue:
|
||||
outputs: list
|
||||
ui: dict = None
|
||||
|
||||
|
||||
class CacheProvider(ABC):
|
||||
"""Abstract base class for external cache providers.
|
||||
Exceptions from provider methods are caught by the caller and never break execution.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||
"""Called after local store. Dispatched via asyncio.create_task."""
|
||||
pass
|
||||
|
||||
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||
"""Return False to skip external caching for this node. Default: True."""
|
||||
return True
|
||||
|
||||
def on_prompt_start(self, prompt_id: str) -> None:
|
||||
pass
|
||||
|
||||
def on_prompt_end(self, prompt_id: str) -> None:
|
||||
pass
|
||||
@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
|
||||
has_first_frame = False
|
||||
for frame in frames:
|
||||
offset_seconds = start_time - frame.pts * audio_stream.time_base
|
||||
to_skip = int(offset_seconds * audio_stream.sample_rate)
|
||||
to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
|
||||
if to_skip < frame.samples:
|
||||
has_first_frame = True
|
||||
break
|
||||
@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
|
||||
audio_frames.append(frame.to_ndarray()[..., to_skip:])
|
||||
|
||||
for frame in frames:
|
||||
if frame.time > start_time + self.__duration:
|
||||
if self.__duration and frame.time > start_time + self.__duration:
|
||||
break
|
||||
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
|
||||
if len(audio_frames) > 0:
|
||||
|
||||
@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
|
||||
'''Float input.'''
|
||||
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
|
||||
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None,
|
||||
display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
|
||||
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
|
||||
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
|
||||
self.min = min
|
||||
|
||||
68
comfy_api_nodes/apis/reve.py
Normal file
68
comfy_api_nodes/apis/reve.py
Normal file
@ -0,0 +1,68 @@
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class RevePostprocessingOperation(BaseModel):
|
||||
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
|
||||
upscale_factor: int | None = Field(
|
||||
None,
|
||||
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
|
||||
ge=2,
|
||||
le=4,
|
||||
)
|
||||
|
||||
|
||||
class ReveImageCreateRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageEditRequest(BaseModel):
|
||||
edit_instruction: str = Field(...)
|
||||
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int | None = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageRemixRequest(BaseModel):
|
||||
prompt: str = Field(...)
|
||||
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
|
||||
aspect_ratio: str | None = Field(...)
|
||||
version: str = Field(...)
|
||||
test_time_scaling: int | None = Field(
|
||||
...,
|
||||
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
|
||||
ge=1,
|
||||
le=15,
|
||||
)
|
||||
postprocessing: list[RevePostprocessingOperation] | None = Field(
|
||||
None, description="Optional postprocessing operations to apply after generation."
|
||||
)
|
||||
|
||||
|
||||
class ReveImageResponse(BaseModel):
|
||||
image: str | None = Field(None, description="The base64 encoded image data.")
|
||||
request_id: str | None = Field(None, description="A unique id for the request.")
|
||||
credits_used: float | None = Field(None, description="The number of credits used for this request.")
|
||||
version: str | None = Field(None, description="The specific model version used.")
|
||||
content_violation: bool | None = Field(
|
||||
None, description="Indicates whether the generated image violates the content policy."
|
||||
)
|
||||
@ -1,3 +1,7 @@
|
||||
import zipfile
|
||||
from io import BytesIO
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
download_url_to_bytesio,
|
||||
download_url_to_file_3d,
|
||||
download_url_to_image_tensor,
|
||||
downscale_image_tensor_by_max_side,
|
||||
poll_op,
|
||||
sync_op,
|
||||
@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
||||
)
|
||||
|
||||
|
||||
class ObjZipResult:
|
||||
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
obj: Types.File3D,
|
||||
texture: Input.Image | None = None,
|
||||
metallic: Input.Image | None = None,
|
||||
normal: Input.Image | None = None,
|
||||
roughness: Input.Image | None = None,
|
||||
):
|
||||
self.obj = obj
|
||||
self.texture = texture
|
||||
self.metallic = metallic
|
||||
self.normal = normal
|
||||
self.roughness = roughness
|
||||
|
||||
|
||||
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
|
||||
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
|
||||
|
||||
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
|
||||
identified by their filename suffixes.
|
||||
"""
|
||||
data = BytesIO()
|
||||
await download_url_to_bytesio(url, data)
|
||||
data.seek(0)
|
||||
if not zipfile.is_zipfile(data):
|
||||
data.seek(0)
|
||||
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
|
||||
data.seek(0)
|
||||
obj_bytes = None
|
||||
textures: dict[str, Input.Image] = {}
|
||||
with zipfile.ZipFile(data) as zf:
|
||||
for name in zf.namelist():
|
||||
lower = name.lower()
|
||||
if lower.endswith(".obj"):
|
||||
obj_bytes = zf.read(name)
|
||||
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
|
||||
stem = lower.rsplit(".", 1)[0]
|
||||
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
|
||||
matched_key = "texture"
|
||||
for suffix, key in {
|
||||
"_metallic": "metallic",
|
||||
"_normal": "normal",
|
||||
"_roughness": "roughness",
|
||||
}.items():
|
||||
if stem.endswith(suffix):
|
||||
matched_key = key
|
||||
break
|
||||
textures[matched_key] = tensor
|
||||
if obj_bytes is None:
|
||||
raise ValueError("ZIP archive does not contain an OBJ file.")
|
||||
return ObjZipResult(
|
||||
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
|
||||
texture=textures.get("texture"),
|
||||
metallic=textures.get("metallic"),
|
||||
normal=textures.get("normal"),
|
||||
roughness=textures.get("roughness"),
|
||||
)
|
||||
|
||||
|
||||
def get_file_from_response(
|
||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||
) -> ResultFile3D | None:
|
||||
@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
obj_result.obj,
|
||||
obj_result.texture,
|
||||
)
|
||||
|
||||
|
||||
@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
IO.Image.Output(display_name="optional_metallic"),
|
||||
IO.Image.Output(display_name="optional_normal"),
|
||||
IO.Image.Output(display_name="optional_roughness"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
|
||||
response_model=To3DProTaskResultResponse,
|
||||
status_extractor=lambda r: r.Status,
|
||||
)
|
||||
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||
return IO.NodeOutput(
|
||||
f"{task_id}.glb",
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||
),
|
||||
await download_url_to_file_3d(
|
||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
||||
),
|
||||
obj_result.obj,
|
||||
obj_result.texture,
|
||||
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
|
||||
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
|
||||
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
|
||||
)
|
||||
|
||||
|
||||
@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DGLB.Output(display_name="GLB"),
|
||||
IO.File3DFBX.Output(display_name="FBX"),
|
||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||
IO.Image.Output(display_name="texture_image"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
||||
)
|
||||
return IO.NodeOutput(
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
|
||||
)
|
||||
|
||||
|
||||
@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
|
||||
TencentTextToModelNode,
|
||||
TencentImageToModelNode,
|
||||
TencentModelTo3DUVNode,
|
||||
# Tencent3DTextureEditNode,
|
||||
Tencent3DTextureEditNode,
|
||||
Tencent3DPartNode,
|
||||
TencentSmartTopologyNode,
|
||||
]
|
||||
|
||||
@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
||||
node_id="KlingOmniProEditVideoNode",
|
||||
display_name="Kling 3.0 Omni Edit Video",
|
||||
category="api node/video/Kling",
|
||||
essentials_category="Video Generation",
|
||||
description="Edit an existing video with the latest model from Kling.",
|
||||
inputs=[
|
||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||
|
||||
@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
||||
node_id="RecraftVectorizeImageNode",
|
||||
display_name="Recraft Vectorize Image",
|
||||
category="api node/image/Recraft",
|
||||
essentials_category="Image Tools",
|
||||
description="Generates SVG synchronously from an input image.",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
|
||||
395
comfy_api_nodes/nodes_reve.py
Normal file
395
comfy_api_nodes/nodes_reve.py
Normal file
@ -0,0 +1,395 @@
|
||||
from io import BytesIO
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.reve import (
|
||||
ReveImageCreateRequest,
|
||||
ReveImageEditRequest,
|
||||
ReveImageRemixRequest,
|
||||
RevePostprocessingOperation,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
ApiEndpoint,
|
||||
bytesio_to_image_tensor,
|
||||
sync_op_raw,
|
||||
tensor_to_base64_string,
|
||||
validate_string,
|
||||
)
|
||||
|
||||
|
||||
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
|
||||
ops = []
|
||||
if upscale["upscale"] == "enabled":
|
||||
ops.append(
|
||||
RevePostprocessingOperation(
|
||||
process="upscale",
|
||||
upscale_factor=upscale["upscale_factor"],
|
||||
)
|
||||
)
|
||||
if remove_background:
|
||||
ops.append(RevePostprocessingOperation(process="remove_background"))
|
||||
return ops or None
|
||||
|
||||
|
||||
def _postprocessing_inputs():
|
||||
return [
|
||||
IO.DynamicCombo.Input(
|
||||
"upscale",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("disabled", []),
|
||||
IO.DynamicCombo.Option(
|
||||
"enabled",
|
||||
[
|
||||
IO.Int.Input(
|
||||
"upscale_factor",
|
||||
default=2,
|
||||
min=2,
|
||||
max=4,
|
||||
step=1,
|
||||
tooltip="Upscale factor (2x, 3x, or 4x).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="Upscale the generated image. May add additional cost.",
|
||||
),
|
||||
IO.Boolean.Input(
|
||||
"remove_background",
|
||||
default=False,
|
||||
tooltip="Remove the background from the generated image. May add additional cost.",
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def _reve_price_extractor(headers: dict) -> float | None:
|
||||
credits_used = headers.get("x-reve-credits-used")
|
||||
if credits_used is not None:
|
||||
return float(credits_used) / 524.48
|
||||
return None
|
||||
|
||||
|
||||
def _reve_response_header_validator(headers: dict) -> None:
|
||||
error_code = headers.get("x-reve-error-code")
|
||||
if error_code:
|
||||
raise ValueError(f"Reve API error: {error_code}")
|
||||
if headers.get("x-reve-content-violation", "").lower() == "true":
|
||||
raise ValueError("The generated image was flagged for content policy violation.")
|
||||
|
||||
|
||||
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
|
||||
return [
|
||||
IO.DynamicCombo.Option(
|
||||
version,
|
||||
[
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
options=aspect_ratios,
|
||||
tooltip="Aspect ratio of the output image.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"test_time_scaling",
|
||||
default=1,
|
||||
min=1,
|
||||
max=5,
|
||||
step=1,
|
||||
tooltip="Higher values produce better images but cost more credits.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
for version in versions
|
||||
]
|
||||
|
||||
|
||||
class ReveImageCreateNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageCreateNode",
|
||||
display_name="Reve Image Create",
|
||||
category="api node/image/Reve",
|
||||
description="Generate images from text descriptions using Reve.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. Maximum 2560 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-create@20250915"],
|
||||
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for generation.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2560)
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/create",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageCreateRequest(
|
||||
prompt=prompt,
|
||||
aspect_ratio=model["aspect_ratio"],
|
||||
version=model["model"],
|
||||
test_time_scaling=model["test_time_scaling"],
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveImageEditNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageEditNode",
|
||||
display_name="Reve Image Edit",
|
||||
category="api node/image/Reve",
|
||||
description="Edit images using natural language instructions with Reve.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="The image to edit."),
|
||||
IO.String.Input(
|
||||
"edit_instruction",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-edit@20250915", "reve-edit-fast@20251030"],
|
||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for editing.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
edit_instruction: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(edit_instruction, min_length=1, max_length=2560)
|
||||
tts = model["test_time_scaling"]
|
||||
ar = model["aspect_ratio"]
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/edit",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageEditRequest(
|
||||
edit_instruction=edit_instruction,
|
||||
reference_image=tensor_to_base64_string(image),
|
||||
aspect_ratio=ar if ar != "auto" else None,
|
||||
version=model["model"],
|
||||
test_time_scaling=tts if tts and tts > 1 else None,
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveImageRemixNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="ReveImageRemixNode",
|
||||
display_name="Reve Image Remix",
|
||||
category="api node/image/Reve",
|
||||
description="Combine reference images with text prompts to create new images using Reve.",
|
||||
inputs=[
|
||||
IO.Autogrow.Input(
|
||||
"reference_images",
|
||||
template=IO.Autogrow.TemplatePrefix(
|
||||
IO.Image.Input("image"),
|
||||
prefix="image_",
|
||||
min=1,
|
||||
max=6,
|
||||
),
|
||||
),
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text description of the desired image. "
|
||||
"May include XML img tags to reference specific images by index, "
|
||||
"e.g. <img>0</img>, <img>1</img>, etc.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=_model_inputs(
|
||||
["reve-remix@20250915", "reve-remix-fast@20251030"],
|
||||
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
|
||||
),
|
||||
tooltip="Model version to use for remixing.",
|
||||
),
|
||||
*_postprocessing_inputs(),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(
|
||||
widgets=["model"],
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isFast := $contains(widgets.model, "fast");
|
||||
$base := $isFast ? 0.01001 : 0.0572;
|
||||
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
reference_images: IO.Autogrow.Type,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
upscale: dict,
|
||||
remove_background: bool,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=2560)
|
||||
if not reference_images:
|
||||
raise ValueError("At least one reference image is required.")
|
||||
ref_base64_list = []
|
||||
for key in reference_images:
|
||||
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
|
||||
if len(ref_base64_list) > 6:
|
||||
raise ValueError("Maximum 6 reference images are allowed.")
|
||||
tts = model["test_time_scaling"]
|
||||
ar = model["aspect_ratio"]
|
||||
response = await sync_op_raw(
|
||||
cls,
|
||||
ApiEndpoint(
|
||||
path="/proxy/reve/v1/image/remix",
|
||||
method="POST",
|
||||
headers={"Accept": "image/webp"},
|
||||
),
|
||||
as_binary=True,
|
||||
price_extractor=_reve_price_extractor,
|
||||
response_header_validator=_reve_response_header_validator,
|
||||
data=ReveImageRemixRequest(
|
||||
prompt=prompt,
|
||||
reference_images=ref_base64_list,
|
||||
aspect_ratio=ar if ar != "auto" else None,
|
||||
version=model["model"],
|
||||
test_time_scaling=tts if tts and tts > 1 else None,
|
||||
postprocessing=_build_postprocessing(upscale, remove_background),
|
||||
),
|
||||
)
|
||||
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
|
||||
|
||||
|
||||
class ReveExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
ReveImageCreateNode,
|
||||
ReveImageEditNode,
|
||||
ReveImageRemixNode,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> ReveExtension:
|
||||
return ReveExtension()
|
||||
@ -67,6 +67,7 @@ class _RequestConfig:
|
||||
progress_origin_ts: float | None = None
|
||||
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None
|
||||
response_header_validator: Callable[[dict[str, str]], None] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -202,11 +203,13 @@ async def sync_op_raw(
|
||||
monitor_progress: bool = True,
|
||||
max_retries_on_rate_limit: int = 16,
|
||||
is_rate_limited: Callable[[int, Any], bool] | None = None,
|
||||
response_header_validator: Callable[[dict[str, str]], None] | None = None,
|
||||
) -> dict[str, Any] | bytes:
|
||||
"""
|
||||
Make a single network request.
|
||||
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
|
||||
- If as_binary=True: returns bytes.
|
||||
- response_header_validator: optional callback receiving response headers dict
|
||||
"""
|
||||
if isinstance(data, BaseModel):
|
||||
data = data.model_dump(exclude_none=True)
|
||||
@ -232,6 +235,7 @@ async def sync_op_raw(
|
||||
price_extractor=price_extractor,
|
||||
max_retries_on_rate_limit=max_retries_on_rate_limit,
|
||||
is_rate_limited=is_rate_limited,
|
||||
response_header_validator=response_header_validator,
|
||||
)
|
||||
return await _request_base(cfg, expect_binary=as_binary)
|
||||
|
||||
@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
|
||||
)
|
||||
bytes_payload = bytes(buff)
|
||||
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
|
||||
if cfg.price_extractor:
|
||||
with contextlib.suppress(Exception):
|
||||
extracted_price = cfg.price_extractor(resp_headers)
|
||||
if cfg.response_header_validator:
|
||||
cfg.response_header_validator(resp_headers)
|
||||
operation_succeeded = True
|
||||
final_elapsed_seconds = int(time.monotonic() - start_time)
|
||||
request_logger.log_request_response(
|
||||
@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
request_method=method,
|
||||
request_url=url,
|
||||
response_status_code=resp.status,
|
||||
response_headers=dict(resp.headers),
|
||||
response_headers=resp_headers,
|
||||
response_content=bytes_payload,
|
||||
)
|
||||
return bytes_payload
|
||||
|
||||
138
comfy_execution/cache_provider.py
Normal file
138
comfy_execution/cache_provider.py
Normal file
@ -0,0 +1,138 @@
|
||||
from typing import Any, Optional, Tuple, List
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
|
||||
# Public types — source of truth is comfy_api.latest._caching
|
||||
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_providers: List[CacheProvider] = []
|
||||
_providers_lock = threading.Lock()
|
||||
_providers_snapshot: Tuple[CacheProvider, ...] = ()
|
||||
|
||||
|
||||
def register_cache_provider(provider: CacheProvider) -> None:
|
||||
"""Register an external cache provider. Providers are called in registration order."""
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
if provider in _providers:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||
return
|
||||
_providers.append(provider)
|
||||
_providers_snapshot = tuple(_providers)
|
||||
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
|
||||
|
||||
|
||||
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
try:
|
||||
_providers.remove(provider)
|
||||
_providers_snapshot = tuple(_providers)
|
||||
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||
except ValueError:
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||
|
||||
|
||||
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||
return _providers_snapshot
|
||||
|
||||
|
||||
def _has_cache_providers() -> bool:
|
||||
return bool(_providers_snapshot)
|
||||
|
||||
|
||||
def _clear_cache_providers() -> None:
|
||||
global _providers_snapshot
|
||||
with _providers_lock:
|
||||
_providers.clear()
|
||||
_providers_snapshot = ()
|
||||
|
||||
|
||||
def _canonicalize(obj: Any) -> Any:
|
||||
# Convert to canonical JSON-serializable form with deterministic ordering.
|
||||
# Frozensets have non-deterministic iteration order between Python sessions.
|
||||
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
|
||||
# _serialize_cache_key returns None and external caching is skipped.
|
||||
if isinstance(obj, frozenset):
|
||||
return ("__frozenset__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, set):
|
||||
return ("__set__", sorted(
|
||||
[_canonicalize(item) for item in obj],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
))
|
||||
elif isinstance(obj, tuple):
|
||||
return ("__tuple__", [_canonicalize(item) for item in obj])
|
||||
elif isinstance(obj, list):
|
||||
return [_canonicalize(item) for item in obj]
|
||||
elif isinstance(obj, dict):
|
||||
return {"__dict__": sorted(
|
||||
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
|
||||
key=lambda x: json.dumps(x, sort_keys=True)
|
||||
)}
|
||||
elif isinstance(obj, (int, float, str, bool, type(None))):
|
||||
return (type(obj).__name__, obj)
|
||||
elif isinstance(obj, bytes):
|
||||
return ("__bytes__", obj.hex())
|
||||
else:
|
||||
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
|
||||
|
||||
|
||||
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||
# Returns deterministic SHA256 hex digest, or None on failure.
|
||||
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
|
||||
try:
|
||||
canonical = _canonicalize(cache_key)
|
||||
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to serialize cache key: {e}")
|
||||
return None
|
||||
|
||||
|
||||
def _contains_self_unequal(obj: Any) -> bool:
|
||||
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
|
||||
# never hit locally, but serialized form would match externally. Skip these.
|
||||
try:
|
||||
if not (obj == obj):
|
||||
return True
|
||||
except Exception:
|
||||
return True
|
||||
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||
return any(_contains_self_unequal(item) for item in obj)
|
||||
if isinstance(obj, dict):
|
||||
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
|
||||
if hasattr(obj, 'value'):
|
||||
return _contains_self_unequal(obj.value)
|
||||
return False
|
||||
|
||||
|
||||
def _estimate_value_size(value: CacheValue) -> int:
|
||||
try:
|
||||
import torch
|
||||
except ImportError:
|
||||
return 0
|
||||
|
||||
total = 0
|
||||
|
||||
def estimate(obj):
|
||||
nonlocal total
|
||||
if isinstance(obj, torch.Tensor):
|
||||
total += obj.numel() * obj.element_size()
|
||||
elif isinstance(obj, dict):
|
||||
for v in obj.values():
|
||||
estimate(v)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
for item in obj:
|
||||
estimate(item)
|
||||
|
||||
for output in value.outputs:
|
||||
estimate(output)
|
||||
return total
|
||||
@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import bisect
|
||||
import gc
|
||||
import itertools
|
||||
@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
|
||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||
|
||||
class BasicCache:
|
||||
def __init__(self, key_class):
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
self.key_class = key_class
|
||||
self.initialized = False
|
||||
self.enable_providers = enable_providers
|
||||
self.dynprompt: DynamicPrompt
|
||||
self.cache_key_set: CacheKeySet
|
||||
self.cache = {}
|
||||
self.subcaches = {}
|
||||
self._pending_store_tasks: set = set()
|
||||
|
||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||
self.dynprompt = dynprompt
|
||||
@ -196,18 +199,138 @@ class BasicCache:
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
def _get_immediate(self, node_id):
|
||||
def get_local(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
else:
|
||||
return None
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
async def _set_immediate(self, node_id, value):
|
||||
assert self.initialized
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
self.cache[cache_key] = value
|
||||
|
||||
await self._notify_providers_store(node_id, cache_key, value)
|
||||
|
||||
async def _get_immediate(self, node_id):
|
||||
if not self.initialized:
|
||||
return None
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
|
||||
if cache_key in self.cache:
|
||||
return self.cache[cache_key]
|
||||
|
||||
external_result = await self._check_providers_lookup(node_id, cache_key)
|
||||
if external_result is not None:
|
||||
self.cache[cache_key] = external_result
|
||||
return external_result
|
||||
|
||||
return None
|
||||
|
||||
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_self_unequal, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return
|
||||
if not _has_cache_providers():
|
||||
return
|
||||
if not self._is_external_cacheable_value(value):
|
||||
return
|
||||
if _contains_self_unequal(cache_key):
|
||||
return
|
||||
|
||||
context = self._build_context(node_id, cache_key)
|
||||
if context is None:
|
||||
return
|
||||
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if provider.should_cache(context, cache_value):
|
||||
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
|
||||
self._pending_store_tasks.add(task)
|
||||
task.add_done_callback(self._pending_store_tasks.discard)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||
|
||||
@staticmethod
|
||||
async def _safe_provider_store(provider, context, cache_value):
|
||||
from comfy_execution.cache_provider import _logger
|
||||
try:
|
||||
await provider.on_store(context, cache_value)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||
|
||||
async def _check_providers_lookup(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import (
|
||||
_has_cache_providers, _get_cache_providers,
|
||||
CacheValue, _contains_self_unequal, _logger
|
||||
)
|
||||
|
||||
if not self.enable_providers:
|
||||
return None
|
||||
if not _has_cache_providers():
|
||||
return None
|
||||
if _contains_self_unequal(cache_key):
|
||||
return None
|
||||
|
||||
context = self._build_context(node_id, cache_key)
|
||||
if context is None:
|
||||
return None
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if not provider.should_cache(context):
|
||||
continue
|
||||
result = await provider.on_lookup(context)
|
||||
if result is not None:
|
||||
if not isinstance(result, CacheValue):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||
continue
|
||||
if not isinstance(result.outputs, (list, tuple)):
|
||||
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||
continue
|
||||
from execution import CacheEntry
|
||||
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
|
||||
except Exception as e:
|
||||
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||
|
||||
return None
|
||||
|
||||
def _is_external_cacheable_value(self, value):
|
||||
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||
|
||||
def _get_class_type(self, node_id):
|
||||
if not self.initialized or not self.dynprompt:
|
||||
return ''
|
||||
try:
|
||||
return self.dynprompt.get_node(node_id).get('class_type', '')
|
||||
except Exception:
|
||||
return ''
|
||||
|
||||
def _build_context(self, node_id, cache_key):
|
||||
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
|
||||
try:
|
||||
cache_key_hash = _serialize_cache_key(cache_key)
|
||||
if cache_key_hash is None:
|
||||
return None
|
||||
return CacheContext(
|
||||
node_id=node_id,
|
||||
class_type=self._get_class_type(node_id),
|
||||
cache_key_hash=cache_key_hash,
|
||||
)
|
||||
except Exception as e:
|
||||
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
|
||||
return None
|
||||
|
||||
async def _ensure_subcache(self, node_id, children_ids):
|
||||
@ -236,8 +359,8 @@ class BasicCache:
|
||||
return result
|
||||
|
||||
class HierarchicalCache(BasicCache):
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class)
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
super().__init__(key_class, enable_providers=enable_providers)
|
||||
|
||||
def _get_cache_for(self, node_id):
|
||||
assert self.dynprompt is not None
|
||||
@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
|
||||
return None
|
||||
return cache
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return cache._get_immediate(node_id)
|
||||
return await cache._get_immediate(node_id)
|
||||
|
||||
def set(self, node_id, value):
|
||||
def get_local(self, node_id):
|
||||
cache = self._get_cache_for(node_id)
|
||||
if cache is None:
|
||||
return None
|
||||
return BasicCache.get_local(cache, node_id)
|
||||
|
||||
async def set(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
cache._set_immediate(node_id, value)
|
||||
await cache._set_immediate(node_id, value)
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
cache = self._get_cache_for(node_id)
|
||||
assert cache is not None
|
||||
BasicCache.set_local(cache, node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
cache = self._get_cache_for(node_id)
|
||||
@ -287,18 +421,24 @@ class NullCache:
|
||||
def poll(self, **kwargs):
|
||||
pass
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
return None
|
||||
|
||||
def set(self, node_id, value):
|
||||
def get_local(self, node_id):
|
||||
return None
|
||||
|
||||
async def set(self, node_id, value):
|
||||
pass
|
||||
|
||||
def set_local(self, node_id, value):
|
||||
pass
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
return self
|
||||
|
||||
class LRUCache(BasicCache):
|
||||
def __init__(self, key_class, max_size=100):
|
||||
super().__init__(key_class)
|
||||
def __init__(self, key_class, max_size=100, enable_providers=False):
|
||||
super().__init__(key_class, enable_providers=enable_providers)
|
||||
self.max_size = max_size
|
||||
self.min_generation = 0
|
||||
self.generation = 0
|
||||
@ -322,18 +462,18 @@ class LRUCache(BasicCache):
|
||||
del self.children[key]
|
||||
self._clean_subcaches()
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
self._mark_used(node_id)
|
||||
return self._get_immediate(node_id)
|
||||
return await self._get_immediate(node_id)
|
||||
|
||||
def _mark_used(self, node_id):
|
||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||
if cache_key is not None:
|
||||
self.used_generation[cache_key] = self.generation
|
||||
|
||||
def set(self, node_id, value):
|
||||
async def set(self, node_id, value):
|
||||
self._mark_used(node_id)
|
||||
return self._set_immediate(node_id, value)
|
||||
return await self._set_immediate(node_id, value)
|
||||
|
||||
async def ensure_subcache_for(self, node_id, children_ids):
|
||||
# Just uses subcaches for tracking 'live' nodes
|
||||
@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
||||
|
||||
class RAMPressureCache(LRUCache):
|
||||
|
||||
def __init__(self, key_class):
|
||||
super().__init__(key_class, 0)
|
||||
def __init__(self, key_class, enable_providers=False):
|
||||
super().__init__(key_class, 0, enable_providers=enable_providers)
|
||||
self.timestamps = {}
|
||||
|
||||
def clean_unused(self):
|
||||
self._clean_subcaches()
|
||||
|
||||
def set(self, node_id, value):
|
||||
async def set(self, node_id, value):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
super().set(node_id, value)
|
||||
await super().set(node_id, value)
|
||||
|
||||
def get(self, node_id):
|
||||
async def get(self, node_id):
|
||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||
return super().get(node_id)
|
||||
return await super().get(node_id)
|
||||
|
||||
def poll(self, ram_headroom):
|
||||
def _ram_gb():
|
||||
|
||||
@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
|
||||
self.execution_cache_listeners = {}
|
||||
|
||||
def is_cached(self, node_id):
|
||||
return self.output_cache.get(node_id) is not None
|
||||
return self.output_cache.get_local(node_id) is not None
|
||||
|
||||
def cache_link(self, from_node_id, to_node_id):
|
||||
if to_node_id not in self.execution_cache:
|
||||
self.execution_cache[to_node_id] = {}
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
|
||||
if from_node_id not in self.execution_cache_listeners:
|
||||
self.execution_cache_listeners[from_node_id] = set()
|
||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||
@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
|
||||
if value is None:
|
||||
return None
|
||||
#Write back to the main cache on touch.
|
||||
self.output_cache.set(from_node_id, value)
|
||||
self.output_cache.set_local(from_node_id, value)
|
||||
return value
|
||||
|
||||
def cache_update(self, node_id, value):
|
||||
|
||||
@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
|
||||
node_id="EmptyLatentAudio",
|
||||
display_name="Empty Latent Audio",
|
||||
category="latent/audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||
IO.Int.Input(
|
||||
@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
|
||||
@ -6,6 +6,7 @@ import comfy.model_management
|
||||
import torch
|
||||
import math
|
||||
import nodes
|
||||
import comfy.ldm.flux.math
|
||||
|
||||
class CLIPTextEncodeFlux(io.ComfyNode):
|
||||
@classmethod
|
||||
@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
|
||||
sigmas = get_schedule(steps, round(seq_len))
|
||||
return io.NodeOutput(sigmas)
|
||||
|
||||
class KV_Attn_Input:
|
||||
def __init__(self):
|
||||
self.cache = {}
|
||||
|
||||
def __call__(self, q, k, v, extra_options, **kwargs):
|
||||
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
|
||||
if len(reference_image_num_tokens) == 0:
|
||||
return {}
|
||||
|
||||
ref_toks = sum(reference_image_num_tokens)
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
|
||||
self.set_cache = True
|
||||
return {"q": q, "k": k, "v": v}
|
||||
|
||||
def cleanup(self):
|
||||
self.cache = {}
|
||||
|
||||
|
||||
class FluxKVCache(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="FluxKVCache",
|
||||
display_name="Flux KV Cache",
|
||||
description="Enables KV Cache optimization for reference images on Flux family models.",
|
||||
category="",
|
||||
is_experimental=True,
|
||||
inputs=[
|
||||
io.Model.Input("model", tooltip="The model to use KV Cache on."),
|
||||
],
|
||||
outputs=[
|
||||
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
|
||||
m = model.clone()
|
||||
input_patch_obj = KV_Attn_Input()
|
||||
|
||||
def model_input_patch(inputs):
|
||||
if len(input_patch_obj.cache) > 0:
|
||||
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
|
||||
if ref_image_tokens > 0:
|
||||
img = inputs["img"]
|
||||
inputs["img"] = img[:, :-ref_image_tokens]
|
||||
return inputs
|
||||
|
||||
m.set_model_attn1_patch(input_patch_obj)
|
||||
m.set_model_post_input_patch(model_input_patch)
|
||||
if hasattr(model.model.diffusion_model, "params"):
|
||||
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
|
||||
else:
|
||||
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
|
||||
|
||||
return io.NodeOutput(m)
|
||||
|
||||
class FluxExtension(ComfyExtension):
|
||||
@override
|
||||
@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
|
||||
FluxKontextMultiReferenceLatentMethod,
|
||||
EmptyFlux2LatentImage,
|
||||
Flux2Scheduler,
|
||||
FluxKVCache,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
|
||||
display_name="Image Compare",
|
||||
description="Compares two images side by side with a slider.",
|
||||
category="image",
|
||||
essentials_category="Image Tools",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
|
||||
@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
|
||||
search_aliases=["trim"],
|
||||
display_name="Image Crop",
|
||||
category="image/transform",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
IO.Image.Input("image"),
|
||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||
|
||||
127
comfy_extras/nodes_painter.py
Normal file
127
comfy_extras/nodes_painter.py
Normal file
@ -0,0 +1,127 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io, UI
|
||||
from typing_extensions import override
|
||||
|
||||
|
||||
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
|
||||
hex_color = hex_color.lstrip("#")
|
||||
if len(hex_color) != 6:
|
||||
return (0.0, 0.0, 0.0)
|
||||
r = int(hex_color[0:2], 16) / 255.0
|
||||
g = int(hex_color[2:4], 16) / 255.0
|
||||
b = int(hex_color[4:6], 16) / 255.0
|
||||
return (r, g, b)
|
||||
|
||||
|
||||
class PainterNode(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="Painter",
|
||||
display_name="Painter",
|
||||
category="image",
|
||||
inputs=[
|
||||
io.Image.Input(
|
||||
"image",
|
||||
optional=True,
|
||||
tooltip="Optional base image to paint over",
|
||||
),
|
||||
io.String.Input(
|
||||
"mask",
|
||||
default="",
|
||||
socketless=True,
|
||||
extra_dict={"widgetType": "PAINTER", "image_upload": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"width",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Int.Input(
|
||||
"height",
|
||||
default=512,
|
||||
min=64,
|
||||
max=4096,
|
||||
step=64,
|
||||
socketless=True,
|
||||
extra_dict={"hidden": True},
|
||||
),
|
||||
io.Color.Input("bg_color", default="#000000"),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("IMAGE"),
|
||||
io.Mask.Output("MASK"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
|
||||
if image is not None:
|
||||
base_image = image[:1]
|
||||
h, w = base_image.shape[1], base_image.shape[2]
|
||||
else:
|
||||
h, w = height, width
|
||||
r, g, b = hex_to_rgb(bg_color)
|
||||
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
|
||||
base_image[0, :, :, 0] = r
|
||||
base_image[0, :, :, 1] = g
|
||||
base_image[0, :, :, 2] = b
|
||||
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
painter_img = node_helpers.pillow(Image.open, mask_path)
|
||||
painter_img = painter_img.convert("RGBA")
|
||||
|
||||
if painter_img.size != (w, h):
|
||||
painter_img = painter_img.resize((w, h), Image.LANCZOS)
|
||||
|
||||
painter_np = np.array(painter_img).astype(np.float32) / 255.0
|
||||
painter_rgb = painter_np[:, :, :3]
|
||||
painter_alpha = painter_np[:, :, 3:4]
|
||||
|
||||
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
|
||||
|
||||
base_np = base_image[0].cpu().numpy()
|
||||
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
|
||||
out_image = torch.from_numpy(composited).unsqueeze(0)
|
||||
else:
|
||||
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
|
||||
out_image = base_image
|
||||
|
||||
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
|
||||
|
||||
@classmethod
|
||||
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
|
||||
if mask and mask.strip():
|
||||
mask_path = folder_paths.get_annotated_filepath(mask)
|
||||
if os.path.exists(mask_path):
|
||||
m = hashlib.sha256()
|
||||
with open(mask_path, "rb") as f:
|
||||
m.update(f.read())
|
||||
return m.digest().hex()
|
||||
return ""
|
||||
|
||||
|
||||
|
||||
class PainterExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self):
|
||||
return [PainterNode]
|
||||
|
||||
|
||||
async def comfy_entrypoint():
|
||||
return PainterExtension()
|
||||
@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
|
||||
node_id="ImageBlend",
|
||||
display_name="Image Blend",
|
||||
category="image/postprocessing",
|
||||
essentials_category="Image Tools",
|
||||
inputs=[
|
||||
io.Image.Input("image1"),
|
||||
io.Image.Input("image2"),
|
||||
|
||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
import node_helpers
|
||||
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
training_dtype=torch.bfloat16,
|
||||
real_dataset=None,
|
||||
bucket_latents=None,
|
||||
use_grad_scaler=False,
|
||||
):
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = optimizer
|
||||
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self.bucket_latents: list[torch.Tensor] | None = (
|
||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||
)
|
||||
# GradScaler for fp16 training
|
||||
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||
# Precompute bucket offsets and weights for sampling
|
||||
if bucket_latents is not None:
|
||||
self._init_bucket_data(bucket_latents)
|
||||
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
batch_sigmas.requires_grad_(True),
|
||||
**batch_extra_args,
|
||||
)
|
||||
loss = self.loss_fn(x0_pred, x0)
|
||||
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||
if bwd:
|
||||
bwd_loss = loss / self.grad_acc
|
||||
bwd_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(bwd_loss).backward()
|
||||
else:
|
||||
bwd_loss.backward()
|
||||
return loss
|
||||
|
||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
)
|
||||
total_loss += loss
|
||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||
total_loss.backward()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.scale(total_loss).backward()
|
||||
else:
|
||||
total_loss.backward()
|
||||
if self.loss_callback:
|
||||
self.loss_callback(total_loss.item())
|
||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||
|
||||
if (i + 1) % self.grad_acc == 0:
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.unscale_(self.optimizer)
|
||||
for param_groups in self.optimizer.param_groups:
|
||||
for param in param_groups["params"]:
|
||||
if param.grad is None:
|
||||
continue
|
||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||
self.optimizer.step()
|
||||
if self.grad_scaler is not None:
|
||||
self.grad_scaler.step(self.optimizer)
|
||||
self.grad_scaler.update()
|
||||
else:
|
||||
self.optimizer.step()
|
||||
self.optimizer.zero_grad()
|
||||
ui_pbar.update(1)
|
||||
torch.cuda.empty_cache()
|
||||
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
|
||||
),
|
||||
io.Combo.Input(
|
||||
"training_dtype",
|
||||
options=["bf16", "fp32"],
|
||||
options=["bf16", "fp32", "none"],
|
||||
default="bf16",
|
||||
tooltip="The dtype to use for training.",
|
||||
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"lora_dtype",
|
||||
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
io.Boolean.Input(
|
||||
"offloading",
|
||||
default=False,
|
||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
||||
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||
),
|
||||
io.Combo.Input(
|
||||
"existing_lora",
|
||||
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
use_grad_scaler = False
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
|
||||
if mp.is_dynamic():
|
||||
if not bypass_mode:
|
||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
||||
bypass_mode = True
|
||||
offloading = True
|
||||
elif offloading:
|
||||
if not bypass_mode:
|
||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, dtype, bucket_mode
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
bucket_latents=latents,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
else:
|
||||
train_sampler = TrainSampler(
|
||||
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
|
||||
seed=seed,
|
||||
training_dtype=dtype,
|
||||
real_dataset=latents if multi_res else None,
|
||||
use_grad_scaler=use_grad_scaler,
|
||||
)
|
||||
|
||||
# Setup guider
|
||||
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
|
||||
io.Int.Input(
|
||||
"steps",
|
||||
optional=True,
|
||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
||||
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||
),
|
||||
],
|
||||
outputs=[],
|
||||
|
||||
@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
|
||||
pbar = comfy.utils.ProgressBar(steps)
|
||||
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
|
||||
oom = False
|
||||
except model_management.OOM_EXCEPTION as e:
|
||||
except Exception as e:
|
||||
model_management.raise_non_oom(e)
|
||||
tile //= 2
|
||||
if tile < 128:
|
||||
raise e
|
||||
|
||||
@ -1,3 +1,3 @@
|
||||
# This file is automatically generated by the build process when version is
|
||||
# updated in pyproject.toml.
|
||||
__version__ = "0.16.3"
|
||||
__version__ = "0.17.0"
|
||||
|
||||
149
execution.py
149
execution.py
@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||
|
||||
|
||||
class ExecutionResult(Enum):
|
||||
@ -126,15 +127,15 @@ class CacheSet:
|
||||
|
||||
# Performs like the old cache -- dump data ASAP
|
||||
def init_classic_cache(self):
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_lru_cache(self, cache_size):
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_ram_cache(self, min_headroom):
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
|
||||
self.objects = HierarchicalCache(CacheKeySetID)
|
||||
|
||||
def init_null_cache(self):
|
||||
@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
cached = caches.outputs.get(unique_id)
|
||||
cached = await caches.outputs.get(unique_id)
|
||||
if cached is not None:
|
||||
if server.client_id is not None:
|
||||
cached_ui = cached.ui or {}
|
||||
@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
server.last_node_id = display_node_id
|
||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||
|
||||
obj = caches.objects.get(unique_id)
|
||||
obj = await caches.objects.get(unique_id)
|
||||
if obj is None:
|
||||
obj = class_def()
|
||||
caches.objects.set(unique_id, obj)
|
||||
await caches.objects.set(unique_id, obj)
|
||||
|
||||
if issubclass(class_def, _ComfyNodeInternal):
|
||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||
@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
|
||||
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
||||
execution_list.cache_update(unique_id, cache_entry)
|
||||
caches.outputs.set(unique_id, cache_entry)
|
||||
await caches.outputs.set(unique_id, cache_entry)
|
||||
|
||||
except comfy.model_management.InterruptProcessingException as iex:
|
||||
logging.info("Processing interrupted")
|
||||
@ -612,7 +613,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
logging.error(traceback.format_exc())
|
||||
tips = ""
|
||||
|
||||
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
|
||||
if comfy.model_management.is_oom(ex):
|
||||
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
|
||||
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
|
||||
logging.error("Got an OOM, unloading all loaded models.")
|
||||
@ -684,6 +685,19 @@ class PromptExecutor:
|
||||
}
|
||||
self.add_message("execution_error", mes, broadcast=False)
|
||||
|
||||
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
|
||||
if not _has_cache_providers():
|
||||
return
|
||||
|
||||
for provider in _get_cache_providers():
|
||||
try:
|
||||
if event == "start":
|
||||
provider.on_prompt_start(prompt_id)
|
||||
elif event == "end":
|
||||
provider.on_prompt_end(prompt_id)
|
||||
except Exception as e:
|
||||
_cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
|
||||
|
||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||
|
||||
@ -700,66 +714,75 @@ class PromptExecutor:
|
||||
self.status_messages = []
|
||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
self._notify_prompt_lifecycle("start", prompt_id)
|
||||
|
||||
cached_nodes = []
|
||||
for node_id in prompt:
|
||||
if self.caches.outputs.get(node_id) is not None:
|
||||
cached_nodes.append(node_id)
|
||||
try:
|
||||
with torch.inference_mode():
|
||||
dynamic_prompt = DynamicPrompt(prompt)
|
||||
reset_progress_state(prompt_id, dynamic_prompt)
|
||||
add_progress_handler(WebUIProgressHandler(self.server))
|
||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||
for cache in self.caches.all:
|
||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||
cache.clean_unused()
|
||||
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
node_ids = list(prompt.keys())
|
||||
cache_results = await asyncio.gather(
|
||||
*(self.caches.outputs.get(node_id) for node_id in node_ids)
|
||||
)
|
||||
cached_nodes = [
|
||||
node_id for node_id, result in zip(node_ids, cache_results)
|
||||
if result is not None
|
||||
]
|
||||
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
comfy.model_management.cleanup_models_gc()
|
||||
self.add_message("execution_cached",
|
||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||
broadcast=False)
|
||||
pending_subgraph_results = {}
|
||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||
ui_node_outputs = {}
|
||||
executed = set()
|
||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||
current_outputs = self.caches.outputs.all_node_ids()
|
||||
for node_id in list(execute_outputs):
|
||||
execution_list.add_node(node_id)
|
||||
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
while not execution_list.is_empty():
|
||||
node_id, error, ex = await execution_list.stage_node_execution()
|
||||
if error is not None:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
assert node_id is not None, "Node ID should not be None at this point"
|
||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||
self.success = result != ExecutionResult.FAILURE
|
||||
if result == ExecutionResult.FAILURE:
|
||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||
break
|
||||
elif result == ExecutionResult.PENDING:
|
||||
execution_list.unstage_node_execution()
|
||||
else: # result == ExecutionResult.SUCCESS:
|
||||
execution_list.complete_node_execution()
|
||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||
else:
|
||||
# Only execute when the while-loop ends without break
|
||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||
|
||||
ui_outputs = {}
|
||||
meta_outputs = {}
|
||||
for node_id, ui_info in ui_node_outputs.items():
|
||||
ui_outputs[node_id] = ui_info["output"]
|
||||
meta_outputs[node_id] = ui_info["meta"]
|
||||
self.history_result = {
|
||||
"outputs": ui_outputs,
|
||||
"meta": meta_outputs,
|
||||
}
|
||||
self.server.last_node_id = None
|
||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||
comfy.model_management.unload_all_models()
|
||||
finally:
|
||||
self._notify_prompt_lifecycle("end", prompt_id)
|
||||
|
||||
|
||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||
|
||||
61
main.py
61
main.py
@ -3,18 +3,22 @@ comfy.options.enable_args_parsing()
|
||||
|
||||
import os
|
||||
import importlib.util
|
||||
import shutil
|
||||
import importlib.metadata
|
||||
import folder_paths
|
||||
import time
|
||||
from comfy.cli_args import args, enables_dynamic_vram
|
||||
from app.logger import setup_logger
|
||||
from app.assets.scanner import seed_assets
|
||||
import itertools
|
||||
import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
from comfy_api import feature_flags
|
||||
from app.database.db import init_db, dependencies_available
|
||||
|
||||
if __name__ == "__main__":
|
||||
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
|
||||
@ -23,6 +27,8 @@ if __name__ == "__main__":
|
||||
|
||||
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
if enables_dynamic_vram():
|
||||
@ -62,8 +68,15 @@ if __name__ == "__main__":
|
||||
|
||||
|
||||
def handle_comfyui_manager_unavailable():
|
||||
if not args.windows_standalone_build:
|
||||
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
|
||||
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
|
||||
uv_available = shutil.which("uv") is not None
|
||||
|
||||
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
|
||||
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
|
||||
if uv_available:
|
||||
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
|
||||
msg += "\n"
|
||||
logging.warning(msg)
|
||||
args.enable_manager = False
|
||||
|
||||
|
||||
@ -161,6 +174,7 @@ def execute_prestartup_script():
|
||||
logging.info("")
|
||||
|
||||
apply_custom_paths()
|
||||
init_mime_types()
|
||||
|
||||
if args.enable_manager:
|
||||
comfyui_manager.prestartup()
|
||||
@ -170,7 +184,6 @@ execute_prestartup_script()
|
||||
|
||||
# Main code
|
||||
import asyncio
|
||||
import shutil
|
||||
import threading
|
||||
import gc
|
||||
|
||||
@ -179,6 +192,7 @@ if 'torch' in sys.modules:
|
||||
|
||||
|
||||
import comfy.utils
|
||||
from app.assets.seeder import asset_seeder
|
||||
|
||||
import execution
|
||||
import server
|
||||
@ -192,8 +206,8 @@ import hook_breaker_ac10a0
|
||||
import comfy.memory_management
|
||||
import comfy.model_patcher
|
||||
|
||||
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
||||
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||
if args.verbose == 'DEBUG':
|
||||
@ -258,6 +272,7 @@ def prompt_worker(q, server_instance):
|
||||
for k in sensitive:
|
||||
extra_data[k] = sensitive[k]
|
||||
|
||||
asset_seeder.pause()
|
||||
e.execute(item[2], prompt_id, extra_data, item[4])
|
||||
need_gc = True
|
||||
|
||||
@ -302,6 +317,7 @@ def prompt_worker(q, server_instance):
|
||||
last_gc_collect = current_time
|
||||
need_gc = False
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
asset_seeder.resume()
|
||||
|
||||
|
||||
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
|
||||
@ -352,12 +368,29 @@ def cleanup_temp():
|
||||
|
||||
def setup_database():
|
||||
try:
|
||||
from app.database.db import init_db, dependencies_available
|
||||
if dependencies_available():
|
||||
init_db()
|
||||
if not args.disable_assets_autoscan:
|
||||
seed_assets(["models"], enable_logging=True)
|
||||
if args.enable_assets:
|
||||
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
|
||||
logging.info("Background asset scan initiated for models, input, output")
|
||||
except Exception as e:
|
||||
if "database is locked" in str(e):
|
||||
logging.error(
|
||||
"Database is locked. Another ComfyUI process is already using this database.\n"
|
||||
"To resolve this, specify a separate database file for this instance:\n"
|
||||
" --database-url sqlite:///path/to/another.db"
|
||||
)
|
||||
sys.exit(1)
|
||||
if args.enable_assets:
|
||||
logging.error(
|
||||
f"Failed to initialize database: {e}\n"
|
||||
"The --enable-assets flag requires a working database connection.\n"
|
||||
"To resolve this, try one of the following:\n"
|
||||
" 1. Install the latest requirements: pip install -r requirements.txt\n"
|
||||
" 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n"
|
||||
" 3. Use an in-memory database: --database-url sqlite:///:memory:"
|
||||
)
|
||||
sys.exit(1)
|
||||
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
|
||||
|
||||
|
||||
@ -429,6 +462,11 @@ if __name__ == "__main__":
|
||||
# Running directly, just start ComfyUI.
|
||||
logging.info("Python version: {}".format(sys.version))
|
||||
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
|
||||
for package in ("comfy-aimdo", "comfy-kitchen"):
|
||||
try:
|
||||
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
|
||||
except:
|
||||
pass
|
||||
|
||||
if sys.version_info.major == 3 and sys.version_info.minor < 10:
|
||||
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
|
||||
@ -440,5 +478,6 @@ if __name__ == "__main__":
|
||||
event_loop.run_until_complete(x)
|
||||
except KeyboardInterrupt:
|
||||
logging.info("\nStopped server")
|
||||
|
||||
cleanup_temp()
|
||||
finally:
|
||||
asset_seeder.shutdown()
|
||||
cleanup_temp()
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.1b1
|
||||
comfyui_manager==4.1b5
|
||||
@ -32,7 +32,7 @@ async def cache_control(
|
||||
)
|
||||
|
||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||
response.headers.setdefault("Cache-Control", "no-cache")
|
||||
response.headers.setdefault("Cache-Control", "no-store")
|
||||
return response
|
||||
|
||||
# Early return for non-image files - no cache headers needed
|
||||
|
||||
15
nodes.py
15
nodes.py
@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
||||
|
||||
|
||||
class ConditioningCombine:
|
||||
ESSENTIALS_CATEGORY = "Image Generation"
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
|
||||
@ -1211,9 +1212,6 @@ class GLIGENTextBoxApply:
|
||||
return (c, )
|
||||
|
||||
class EmptyLatentImage:
|
||||
def __init__(self):
|
||||
self.device = comfy.model_management.intermediate_device()
|
||||
|
||||
@classmethod
|
||||
def INPUT_TYPES(s):
|
||||
return {
|
||||
@ -1232,7 +1230,7 @@ class EmptyLatentImage:
|
||||
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
|
||||
|
||||
def generate(self, width, height, batch_size=1):
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
|
||||
|
||||
|
||||
@ -1724,6 +1722,8 @@ class LoadImage:
|
||||
output_masks = []
|
||||
w, h = None, None
|
||||
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
|
||||
for i in ImageSequence.Iterator(img):
|
||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||
|
||||
@ -1748,8 +1748,8 @@ class LoadImage:
|
||||
mask = 1. - torch.from_numpy(mask)
|
||||
else:
|
||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||
output_images.append(image)
|
||||
output_masks.append(mask.unsqueeze(0))
|
||||
output_images.append(image.to(dtype=dtype))
|
||||
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||
|
||||
if img.format == "MPO":
|
||||
break # ignore all frames except the first one for MPO format
|
||||
@ -1779,6 +1779,7 @@ class LoadImage:
|
||||
return True
|
||||
|
||||
class LoadImageMask:
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||
|
||||
_color_channels = ["alpha", "red", "green", "blue"]
|
||||
@ -1887,6 +1888,7 @@ class ImageScale:
|
||||
return (s,)
|
||||
|
||||
class ImageScaleBy:
|
||||
ESSENTIALS_CATEGORY = "Image Tools"
|
||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||
|
||||
@classmethod
|
||||
@ -2450,6 +2452,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_nag.py",
|
||||
"nodes_sdpose.py",
|
||||
"nodes_math.py",
|
||||
"nodes_painter.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
[project]
|
||||
name = "ComfyUI"
|
||||
version = "0.16.3"
|
||||
version = "0.17.0"
|
||||
readme = "README.md"
|
||||
license = { file = "LICENSE" }
|
||||
requires-python = ">=3.10"
|
||||
|
||||
@ -1,5 +1,5 @@
|
||||
comfyui-frontend-package==1.39.19
|
||||
comfyui-workflow-templates==0.9.10
|
||||
comfyui-frontend-package==1.41.20
|
||||
comfyui-workflow-templates==0.9.21
|
||||
comfyui-embedded-docs==0.4.3
|
||||
torch
|
||||
torchsde
|
||||
@ -20,11 +20,13 @@ tqdm
|
||||
psutil
|
||||
alembic
|
||||
SQLAlchemy
|
||||
filelock
|
||||
av>=14.2.0
|
||||
comfy-kitchen>=0.2.7
|
||||
comfy-aimdo>=0.2.7
|
||||
comfy-kitchen>=0.2.8
|
||||
comfy-aimdo>=0.2.12
|
||||
requests
|
||||
simpleeval>=1.0
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
#non essential dependencies:
|
||||
kornia>=0.7.1
|
||||
|
||||
100
server.py
100
server.py
@ -33,8 +33,10 @@ import node_helpers
|
||||
from comfyui_version import __version__
|
||||
from app.frontend_management import FrontendManager, parse_version
|
||||
from comfy_api.internal import _ComfyNodeInternal
|
||||
from app.assets.scanner import seed_assets
|
||||
from app.assets.api.routes import register_assets_system
|
||||
from app.assets.seeder import asset_seeder
|
||||
from app.assets.api.routes import register_assets_routes
|
||||
from app.assets.services.ingest import register_file_in_place
|
||||
from app.assets.services.asset_management import resolve_hash_to_path
|
||||
|
||||
from app.user_manager import UserManager
|
||||
from app.model_manager import ModelFileManager
|
||||
@ -197,10 +199,6 @@ class PromptServer():
|
||||
def __init__(self, loop):
|
||||
PromptServer.instance = self
|
||||
|
||||
mimetypes.init()
|
||||
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
|
||||
mimetypes.add_type('image/webp', '.webp')
|
||||
|
||||
self.user_manager = UserManager()
|
||||
self.model_file_manager = ModelFileManager()
|
||||
self.custom_node_manager = CustomNodeManager()
|
||||
@ -239,7 +237,11 @@ class PromptServer():
|
||||
else args.front_end_root
|
||||
)
|
||||
logging.info(f"[Prompt Server] web root: {self.web_root}")
|
||||
register_assets_system(self.app, self.user_manager)
|
||||
if args.enable_assets:
|
||||
register_assets_routes(self.app, self.user_manager)
|
||||
else:
|
||||
register_assets_routes(self.app)
|
||||
asset_seeder.disable()
|
||||
routes = web.RouteTableDef()
|
||||
self.routes = routes
|
||||
self.last_node_id = None
|
||||
@ -310,7 +312,7 @@ class PromptServer():
|
||||
@routes.get("/")
|
||||
async def get_root(request):
|
||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||
response.headers['Cache-Control'] = 'no-cache'
|
||||
response.headers['Cache-Control'] = 'no-store, must-revalidate'
|
||||
response.headers["Pragma"] = "no-cache"
|
||||
response.headers["Expires"] = "0"
|
||||
return response
|
||||
@ -419,7 +421,24 @@ class PromptServer():
|
||||
with open(filepath, "wb") as f:
|
||||
f.write(image.file.read())
|
||||
|
||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
||||
resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
|
||||
|
||||
if args.enable_assets:
|
||||
try:
|
||||
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||
resp["asset"] = {
|
||||
"id": result.ref.id,
|
||||
"name": result.ref.name,
|
||||
"asset_hash": result.asset.hash,
|
||||
"size": result.asset.size_bytes,
|
||||
"mime_type": result.asset.mime_type,
|
||||
"tags": result.tags,
|
||||
}
|
||||
except Exception:
|
||||
logging.warning("Failed to register uploaded image as asset", exc_info=True)
|
||||
|
||||
return web.json_response(resp)
|
||||
else:
|
||||
return web.Response(status=400)
|
||||
|
||||
@ -479,30 +498,43 @@ class PromptServer():
|
||||
async def view_image(request):
|
||||
if "filename" in request.rel_url.query:
|
||||
filename = request.rel_url.query["filename"]
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
# The frontend's LoadImage combo widget uses asset_hash values
|
||||
# (e.g. "blake3:...") as widget values. When litegraph renders the
|
||||
# node preview, it constructs /view?filename=<asset_hash>, so this
|
||||
# endpoint must resolve blake3 hashes to their on-disk file paths.
|
||||
if filename.startswith("blake3:"):
|
||||
owner_id = self.user_manager.get_request_user_id(request)
|
||||
result = resolve_hash_to_path(filename, owner_id=owner_id)
|
||||
if result is None:
|
||||
return web.Response(status=404)
|
||||
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
|
||||
else:
|
||||
resolved_content_type = None
|
||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
if not filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
# validation for security: prevent accessing arbitrary path
|
||||
if filename[0] == '/' or '..' in filename:
|
||||
return web.Response(status=400)
|
||||
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
if output_dir is None:
|
||||
type = request.rel_url.query.get("type", "output")
|
||||
output_dir = folder_paths.get_directory_by_type(type)
|
||||
|
||||
if "subfolder" in request.rel_url.query:
|
||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
if output_dir is None:
|
||||
return web.Response(status=400)
|
||||
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
if "subfolder" in request.rel_url.query:
|
||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||
return web.Response(status=403)
|
||||
output_dir = full_output_dir
|
||||
|
||||
filename = os.path.basename(filename)
|
||||
file = os.path.join(output_dir, filename)
|
||||
|
||||
if os.path.isfile(file):
|
||||
if 'preview' in request.rel_url.query:
|
||||
@ -562,8 +594,13 @@ class PromptServer():
|
||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||
else:
|
||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
||||
# Use the content type from asset resolution if available,
|
||||
# otherwise guess from the filename.
|
||||
content_type = (
|
||||
resolved_content_type
|
||||
or mimetypes.guess_type(filename)[0]
|
||||
or 'application/octet-stream'
|
||||
)
|
||||
|
||||
# For security, force certain mimetypes to download instead of display
|
||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||
@ -697,10 +734,7 @@ class PromptServer():
|
||||
|
||||
@routes.get("/object_info")
|
||||
async def get_object_info(request):
|
||||
try:
|
||||
seed_assets(["models"])
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to seed assets: {e}")
|
||||
asset_seeder.start(roots=("models", "input", "output"))
|
||||
with folder_paths.cache_helper:
|
||||
out = {}
|
||||
for x in nodes.NODE_CLASS_MAPPINGS:
|
||||
|
||||
57
tests-unit/app_test/test_migrations.py
Normal file
57
tests-unit/app_test/test_migrations.py
Normal file
@ -0,0 +1,57 @@
|
||||
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
|
||||
|
||||
This catches problems like unnamed FK constraints that prevent batch-mode
|
||||
drop_constraint from working on real SQLite files (see MB-2).
|
||||
|
||||
Migrations 0001 and 0002 are already shipped, so we only exercise
|
||||
upgrade/downgrade for 0003+.
|
||||
"""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
from alembic import command
|
||||
from alembic.config import Config
|
||||
|
||||
|
||||
# Oldest shipped revision — we upgrade to here as a baseline and never
|
||||
# downgrade past it.
|
||||
_BASELINE = "0002_merge_to_asset_references"
|
||||
|
||||
|
||||
def _make_config(db_path: str) -> Config:
|
||||
root = os.path.join(os.path.dirname(__file__), "../..")
|
||||
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
|
||||
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
|
||||
|
||||
cfg = Config(config_path)
|
||||
cfg.set_main_option("script_location", scripts_path)
|
||||
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||
return cfg
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def migration_db(tmp_path):
|
||||
"""Yield an alembic Config pre-upgraded to the baseline revision."""
|
||||
db_path = str(tmp_path / "test_migration.db")
|
||||
cfg = _make_config(db_path)
|
||||
command.upgrade(cfg, _BASELINE)
|
||||
yield cfg
|
||||
|
||||
|
||||
def test_upgrade_to_head(migration_db):
|
||||
"""Upgrade from baseline to head must succeed on a file-backed DB."""
|
||||
command.upgrade(migration_db, "head")
|
||||
|
||||
|
||||
def test_downgrade_to_baseline(migration_db):
|
||||
"""Upgrade to head then downgrade back to baseline."""
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
|
||||
|
||||
def test_upgrade_downgrade_cycle(migration_db):
|
||||
"""Full cycle: upgrade → downgrade → upgrade again."""
|
||||
command.upgrade(migration_db, "head")
|
||||
command.downgrade(migration_db, _BASELINE)
|
||||
command.upgrade(migration_db, "head")
|
||||
@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest)
|
||||
"main.py",
|
||||
f"--base-directory={str(comfy_tmp_base_dir)}",
|
||||
f"--database-url={db_url}",
|
||||
"--disable-assets-autoscan",
|
||||
"--enable-assets",
|
||||
"--listen",
|
||||
"127.0.0.1",
|
||||
"--port",
|
||||
@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str):
|
||||
|
||||
for aid in created:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
|
||||
"""Force a fast sync/seed pass by calling the seed endpoint."""
|
||||
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
|
||||
time.sleep(0.2)
|
||||
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
return asset_hash.removeprefix("blake3:") + extension
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
|
||||
|
||||
28
tests-unit/assets_test/helpers.py
Normal file
28
tests-unit/assets_test/helpers.py
Normal file
@ -0,0 +1,28 @@
|
||||
"""Helper functions for assets integration tests."""
|
||||
import time
|
||||
|
||||
import requests
|
||||
|
||||
|
||||
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
|
||||
"""Force a synchronous sync/seed pass by calling the seed endpoint with wait=true.
|
||||
|
||||
Retries on 409 (already running) until the previous scan finishes.
|
||||
"""
|
||||
deadline = time.monotonic() + 60
|
||||
while True:
|
||||
r = session.post(
|
||||
base_url + "/api/assets/seed?wait=true",
|
||||
json={"roots": ["models", "input", "output"]},
|
||||
timeout=60,
|
||||
)
|
||||
if r.status_code != 409:
|
||||
assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}"
|
||||
return
|
||||
if time.monotonic() > deadline:
|
||||
raise TimeoutError("seed endpoint stuck in 409 (already running)")
|
||||
time.sleep(0.25)
|
||||
|
||||
|
||||
def get_asset_filename(asset_hash: str, extension: str) -> str:
|
||||
return asset_hash.removeprefix("blake3:") + extension
|
||||
20
tests-unit/assets_test/queries/conftest.py
Normal file
20
tests-unit/assets_test/queries/conftest.py
Normal file
@ -0,0 +1,20 @@
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Base
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def session():
|
||||
"""In-memory SQLite session for fast unit tests."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
Base.metadata.create_all(engine)
|
||||
with Session(engine) as sess:
|
||||
yield sess
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def autoclean_unit_test_assets():
|
||||
"""Override parent autouse fixture - query tests don't need server cleanup."""
|
||||
yield
|
||||
187
tests-unit/assets_test/queries/test_asset.py
Normal file
187
tests-unit/assets_test/queries/test_asset.py
Normal file
@ -0,0 +1,187 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
asset_exists_by_hash,
|
||||
get_asset_by_hash,
|
||||
upsert_asset,
|
||||
bulk_insert_assets,
|
||||
update_asset_hash_and_mime,
|
||||
)
|
||||
|
||||
|
||||
class TestAssetExistsByHash:
|
||||
@pytest.mark.parametrize(
|
||||
"setup_hash,query_hash,expected",
|
||||
[
|
||||
(None, "nonexistent", False), # No asset exists
|
||||
("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash
|
||||
(None, "", False), # Null hash in DB doesn't match empty string
|
||||
],
|
||||
ids=["nonexistent", "existing", "null_hash_no_match"],
|
||||
)
|
||||
def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
|
||||
if setup_hash is not None or query_hash == "":
|
||||
asset = Asset(hash=setup_hash, size_bytes=100)
|
||||
session.add(asset)
|
||||
session.commit()
|
||||
|
||||
assert asset_exists_by_hash(session, asset_hash=query_hash) is expected
|
||||
|
||||
|
||||
class TestGetAssetByHash:
|
||||
@pytest.mark.parametrize(
|
||||
"setup_hash,query_hash,should_find",
|
||||
[
|
||||
(None, "nonexistent", False),
|
||||
("blake3:def456", "blake3:def456", True),
|
||||
],
|
||||
ids=["nonexistent", "existing"],
|
||||
)
|
||||
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
|
||||
if setup_hash is not None:
|
||||
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.commit()
|
||||
|
||||
result = get_asset_by_hash(session, asset_hash=query_hash)
|
||||
if should_find:
|
||||
assert result is not None
|
||||
assert result.size_bytes == 200
|
||||
assert result.mime_type == "image/png"
|
||||
else:
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestUpsertAsset:
|
||||
@pytest.mark.parametrize(
|
||||
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
|
||||
[
|
||||
# New asset creation
|
||||
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
|
||||
# Existing asset, same values - no update
|
||||
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
|
||||
# Existing asset with size 0, update with new values
|
||||
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
|
||||
# Existing asset, second call with size 0 - no update
|
||||
(1000, None, 0, None, False, False, 1000, None),
|
||||
],
|
||||
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
|
||||
)
|
||||
def test_upsert_scenarios(
|
||||
self,
|
||||
session: Session,
|
||||
first_size,
|
||||
first_mime,
|
||||
second_size,
|
||||
second_mime,
|
||||
expect_created,
|
||||
expect_updated,
|
||||
final_size,
|
||||
final_mime,
|
||||
):
|
||||
asset_hash = f"blake3:test_{first_size}_{second_size}"
|
||||
|
||||
# First upsert (if first_size is not None, we're testing the second call)
|
||||
if first_size is not None:
|
||||
upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=first_size,
|
||||
mime_type=first_mime,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# The upsert call we're testing
|
||||
asset, created, updated = upsert_asset(
|
||||
session,
|
||||
asset_hash=asset_hash,
|
||||
size_bytes=second_size,
|
||||
mime_type=second_mime,
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is expect_created
|
||||
assert updated is expect_updated
|
||||
assert asset.size_bytes == final_size
|
||||
assert asset.mime_type == final_mime
|
||||
|
||||
|
||||
class TestBulkInsertAssets:
|
||||
def test_inserts_multiple_assets(self, session: Session):
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now},
|
||||
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now},
|
||||
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now},
|
||||
]
|
||||
bulk_insert_assets(session, rows)
|
||||
session.commit()
|
||||
|
||||
assets = session.query(Asset).all()
|
||||
assert len(assets) == 3
|
||||
hashes = {a.hash for a in assets}
|
||||
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
bulk_insert_assets(session, [])
|
||||
session.commit()
|
||||
assert session.query(Asset).count() == 0
|
||||
|
||||
def test_handles_large_batch(self, session: Session):
|
||||
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now}
|
||||
for i in range(200)
|
||||
]
|
||||
bulk_insert_assets(session, rows)
|
||||
session.commit()
|
||||
|
||||
assert session.query(Asset).count() == 200
|
||||
|
||||
|
||||
class TestMimeTypeImmutability:
|
||||
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mime,second_mime,expected_mime",
|
||||
[
|
||||
("image/png", "image/jpeg", "image/png"),
|
||||
(None, "image/png", "image/png"),
|
||||
],
|
||||
ids=["preserves_existing", "fills_null"],
|
||||
)
|
||||
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
|
||||
h = f"blake3:upsert_{initial_mime}_{second_mime}"
|
||||
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
|
||||
session.commit()
|
||||
|
||||
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
|
||||
assert created is False
|
||||
assert asset.mime_type == expected_mime
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
|
||||
[
|
||||
(None, "image/png", None, "image/png", "blake3:upd0"),
|
||||
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
|
||||
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
|
||||
],
|
||||
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
|
||||
)
|
||||
def test_update_asset_hash_and_mime_immutability(
|
||||
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
|
||||
):
|
||||
h = expected_hash.removesuffix("_new")
|
||||
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
|
||||
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
|
||||
assert asset.mime_type == expected_mime
|
||||
assert asset.hash == expected_hash
|
||||
520
tests-unit/assets_test/queries/test_asset_info.py
Normal file
520
tests-unit/assets_test/queries/test_asset_info.py
Normal file
@ -0,0 +1,520 @@
|
||||
import time
|
||||
import uuid
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
|
||||
from app.assets.database.queries import (
|
||||
reference_exists_for_asset_id,
|
||||
get_reference_by_id,
|
||||
insert_reference,
|
||||
get_or_create_reference,
|
||||
update_reference_timestamps,
|
||||
list_references_page,
|
||||
fetch_reference_asset_and_tags,
|
||||
fetch_reference_and_asset,
|
||||
update_reference_access_time,
|
||||
set_reference_metadata,
|
||||
delete_reference_by_id,
|
||||
set_reference_preview,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
get_reference_ids_by_ids,
|
||||
ensure_tags_exist,
|
||||
add_tags_to_reference,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
|
||||
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def _make_reference(
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
name: str = "test",
|
||||
owner_id: str = "",
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
|
||||
class TestReferenceExistsForAssetId:
|
||||
def test_returns_false_when_no_reference(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
assert reference_exists_for_asset_id(session, asset_id=asset.id) is False
|
||||
|
||||
def test_returns_true_when_reference_exists(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset)
|
||||
assert reference_exists_for_asset_id(session, asset_id=asset.id) is True
|
||||
|
||||
|
||||
class TestGetReferenceById:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
assert get_reference_by_id(session, reference_id="nonexistent") is None
|
||||
|
||||
def test_returns_reference(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset, name="myfile.txt")
|
||||
|
||||
result = get_reference_by_id(session, reference_id=ref.id)
|
||||
assert result is not None
|
||||
assert result.name == "myfile.txt"
|
||||
|
||||
|
||||
class TestListReferencesPage:
|
||||
def test_empty_db(self, session: Session):
|
||||
refs, tag_map, total = list_references_page(session)
|
||||
assert refs == []
|
||||
assert tag_map == {}
|
||||
assert total == 0
|
||||
|
||||
def test_returns_references_with_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset, name="test.bin")
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"])
|
||||
session.commit()
|
||||
|
||||
refs, tag_map, total = list_references_page(session)
|
||||
assert len(refs) == 1
|
||||
assert refs[0].id == ref.id
|
||||
assert set(tag_map[ref.id]) == {"alpha", "beta"}
|
||||
assert total == 1
|
||||
|
||||
def test_name_contains_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, name="model_v1.safetensors")
|
||||
_make_reference(session, asset, name="config.json")
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(session, name_contains="model")
|
||||
assert total == 1
|
||||
assert refs[0].name == "model_v1.safetensors"
|
||||
|
||||
def test_owner_visibility(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, name="public", owner_id="")
|
||||
_make_reference(session, asset, name="private", owner_id="user1")
|
||||
session.commit()
|
||||
|
||||
# Empty owner sees only public
|
||||
refs, _, total = list_references_page(session, owner_id="")
|
||||
assert total == 1
|
||||
assert refs[0].name == "public"
|
||||
|
||||
# Owner sees both
|
||||
refs, _, total = list_references_page(session, owner_id="user1")
|
||||
assert total == 2
|
||||
|
||||
def test_include_tags_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref1 = _make_reference(session, asset, name="tagged")
|
||||
_make_reference(session, asset, name="untagged")
|
||||
ensure_tags_exist(session, ["wanted"])
|
||||
add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"])
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(session, include_tags=["wanted"])
|
||||
assert total == 1
|
||||
assert refs[0].name == "tagged"
|
||||
|
||||
def test_exclude_tags_filter(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, name="keep")
|
||||
ref_exclude = _make_reference(session, asset, name="exclude")
|
||||
ensure_tags_exist(session, ["bad"])
|
||||
add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"])
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(session, exclude_tags=["bad"])
|
||||
assert total == 1
|
||||
assert refs[0].name == "keep"
|
||||
|
||||
def test_sorting(self, session: Session):
|
||||
asset = _make_asset(session, "hash1", size=100)
|
||||
asset2 = _make_asset(session, "hash2", size=500)
|
||||
_make_reference(session, asset, name="small")
|
||||
_make_reference(session, asset2, name="large")
|
||||
session.commit()
|
||||
|
||||
refs, _, _ = list_references_page(session, sort="size", order="desc")
|
||||
assert refs[0].name == "large"
|
||||
|
||||
refs, _, _ = list_references_page(session, sort="name", order="asc")
|
||||
assert refs[0].name == "large"
|
||||
|
||||
|
||||
class TestFetchReferenceAssetAndTags:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_reference_asset_and_tags(session, "nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_returns_tuple(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset, name="test.bin")
|
||||
ensure_tags_exist(session, ["tag1"])
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"])
|
||||
session.commit()
|
||||
|
||||
result = fetch_reference_asset_and_tags(session, ref.id)
|
||||
assert result is not None
|
||||
ret_ref, ret_asset, ret_tags = result
|
||||
assert ret_ref.id == ref.id
|
||||
assert ret_asset.id == asset.id
|
||||
assert ret_tags == ["tag1"]
|
||||
|
||||
|
||||
class TestFetchReferenceAndAsset:
|
||||
def test_returns_none_for_nonexistent(self, session: Session):
|
||||
result = fetch_reference_and_asset(session, reference_id="nonexistent")
|
||||
assert result is None
|
||||
|
||||
def test_returns_tuple(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
result = fetch_reference_and_asset(session, reference_id=ref.id)
|
||||
assert result is not None
|
||||
ret_ref, ret_asset = result
|
||||
assert ret_ref.id == ref.id
|
||||
assert ret_asset.id == asset.id
|
||||
|
||||
|
||||
class TestUpdateReferenceAccessTime:
|
||||
def test_updates_last_access_time(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
original_time = ref.last_access_time
|
||||
session.commit()
|
||||
|
||||
import time
|
||||
time.sleep(0.01)
|
||||
|
||||
update_reference_access_time(session, reference_id=ref.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.last_access_time > original_time
|
||||
|
||||
|
||||
class TestDeleteReferenceById:
|
||||
def test_deletes_existing(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
result = delete_reference_by_id(session, reference_id=ref.id, owner_id="")
|
||||
assert result is True
|
||||
assert get_reference_by_id(session, reference_id=ref.id) is None
|
||||
|
||||
def test_returns_false_for_nonexistent(self, session: Session):
|
||||
result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="")
|
||||
assert result is False
|
||||
|
||||
def test_respects_owner_visibility(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset, owner_id="user1")
|
||||
session.commit()
|
||||
|
||||
result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2")
|
||||
assert result is False
|
||||
assert get_reference_by_id(session, reference_id=ref.id) is not None
|
||||
|
||||
|
||||
class TestSetReferencePreview:
|
||||
def test_sets_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
session.commit()
|
||||
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id == preview_ref.id
|
||||
|
||||
def test_clears_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
ref.preview_id = preview_ref.id
|
||||
session.commit()
|
||||
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id is None
|
||||
|
||||
def test_raises_for_nonexistent_reference(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
|
||||
|
||||
def test_raises_for_nonexistent_preview(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
with pytest.raises(ValueError, match="Preview AssetReference"):
|
||||
set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
|
||||
|
||||
|
||||
class TestInsertReference:
|
||||
def test_creates_new_reference(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = insert_reference(
|
||||
session, asset_id=asset.id, owner_id="user1", name="test.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert ref is not None
|
||||
assert ref.name == "test.bin"
|
||||
assert ref.owner_id == "user1"
|
||||
|
||||
def test_allows_duplicate_names(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
|
||||
session.commit()
|
||||
|
||||
# Duplicate names are now allowed
|
||||
ref2 = insert_reference(
|
||||
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert ref1 is not None
|
||||
assert ref2 is not None
|
||||
assert ref1.id != ref2.id
|
||||
|
||||
|
||||
class TestGetOrCreateReference:
|
||||
def test_creates_new_reference(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref, created = get_or_create_reference(
|
||||
session, asset_id=asset.id, owner_id="user1", name="new.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is True
|
||||
assert ref.name == "new.bin"
|
||||
|
||||
def test_always_creates_new_reference(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref1, created1 = get_or_create_reference(
|
||||
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Duplicate names are allowed, so always creates new
|
||||
ref2, created2 = get_or_create_reference(
|
||||
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created1 is True
|
||||
assert created2 is True
|
||||
assert ref1.id != ref2.id
|
||||
|
||||
|
||||
class TestUpdateReferenceTimestamps:
|
||||
def test_updates_timestamps(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
original_updated_at = ref.updated_at
|
||||
session.commit()
|
||||
|
||||
time.sleep(0.01)
|
||||
update_reference_timestamps(session, ref)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.updated_at > original_updated_at
|
||||
|
||||
def test_updates_preview_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
preview_asset = _make_asset(session, "preview_hash")
|
||||
ref = _make_reference(session, asset)
|
||||
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||
session.commit()
|
||||
|
||||
update_reference_timestamps(session, ref, preview_id=preview_ref.id)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.preview_id == preview_ref.id
|
||||
|
||||
|
||||
class TestSetReferenceMetadata:
|
||||
def test_sets_metadata(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_reference_metadata(
|
||||
session, reference_id=ref.id, user_metadata={"key": "value"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.user_metadata == {"key": "value"}
|
||||
# Check metadata table
|
||||
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
|
||||
assert len(meta) == 1
|
||||
assert meta[0].key == "key"
|
||||
assert meta[0].val_str == "value"
|
||||
|
||||
def test_replaces_existing_metadata(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_reference_metadata(
|
||||
session, reference_id=ref.id, user_metadata={"old": "data"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
set_reference_metadata(
|
||||
session, reference_id=ref.id, user_metadata={"new": "data"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
|
||||
assert len(meta) == 1
|
||||
assert meta[0].key == "new"
|
||||
|
||||
def test_clears_metadata_with_empty_dict(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
set_reference_metadata(
|
||||
session, reference_id=ref.id, user_metadata={"key": "value"}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
set_reference_metadata(
|
||||
session, reference_id=ref.id, user_metadata={}
|
||||
)
|
||||
session.commit()
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.user_metadata == {}
|
||||
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
|
||||
assert len(meta) == 0
|
||||
|
||||
def test_raises_for_nonexistent(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
set_reference_metadata(
|
||||
session, reference_id="nonexistent", user_metadata={"key": "value"}
|
||||
)
|
||||
|
||||
|
||||
class TestBulkInsertReferencesIgnoreConflicts:
|
||||
def test_inserts_multiple_references(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "bulk1.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "bulk2.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
]
|
||||
bulk_insert_references_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
refs = session.query(AssetReference).all()
|
||||
assert len(refs) == 2
|
||||
|
||||
def test_allows_duplicate_names(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, name="existing.bin", owner_id="")
|
||||
session.commit()
|
||||
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "existing.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
{
|
||||
"id": str(uuid.uuid4()),
|
||||
"owner_id": "",
|
||||
"name": "new.bin",
|
||||
"asset_id": asset.id,
|
||||
"preview_id": None,
|
||||
"user_metadata": {},
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
]
|
||||
bulk_insert_references_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
# Duplicate names allowed, so all 3 rows exist
|
||||
refs = session.query(AssetReference).all()
|
||||
assert len(refs) == 3
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
bulk_insert_references_ignore_conflicts(session, [])
|
||||
assert session.query(AssetReference).count() == 0
|
||||
|
||||
|
||||
class TestGetReferenceIdsByIds:
|
||||
def test_returns_existing_ids(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref1 = _make_reference(session, asset, name="a.bin")
|
||||
ref2 = _make_reference(session, asset, name="b.bin")
|
||||
session.commit()
|
||||
|
||||
found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"])
|
||||
|
||||
assert found == {ref1.id, ref2.id}
|
||||
|
||||
def test_empty_list_returns_empty(self, session: Session):
|
||||
found = get_reference_ids_by_ids(session, [])
|
||||
assert found == set()
|
||||
499
tests-unit/assets_test/queries/test_cache_state.py
Normal file
499
tests-unit/assets_test/queries/test_cache_state.py
Normal file
@ -0,0 +1,499 @@
|
||||
"""Tests for cache_state (AssetReference file path) query functions."""
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries import (
|
||||
list_references_by_asset_id,
|
||||
upsert_reference,
|
||||
get_unreferenced_unhashed_asset_ids,
|
||||
delete_assets_by_ids,
|
||||
get_references_for_prefixes,
|
||||
bulk_update_needs_verify,
|
||||
delete_references_by_ids,
|
||||
delete_orphaned_seed_asset,
|
||||
bulk_insert_references_ignore_conflicts,
|
||||
get_references_by_paths_and_asset_ids,
|
||||
mark_references_missing_outside_prefixes,
|
||||
restore_references_by_paths,
|
||||
)
|
||||
from app.assets.helpers import select_best_live_path, get_utc_now
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
|
||||
asset = Asset(hash=hash_val, size_bytes=size)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def _make_reference(
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
file_path: str,
|
||||
name: str = "test",
|
||||
mtime_ns: int | None = None,
|
||||
needs_verify: bool = False,
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
asset_id=asset.id,
|
||||
file_path=file_path,
|
||||
name=name,
|
||||
mtime_ns=mtime_ns,
|
||||
needs_verify=needs_verify,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
|
||||
class TestListReferencesByAssetId:
|
||||
def test_returns_empty_for_no_references(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
assert list(refs) == []
|
||||
|
||||
def test_returns_references_for_asset(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "/path/a.bin", name="a")
|
||||
_make_reference(session, asset, "/path/b.bin", name="b")
|
||||
session.commit()
|
||||
|
||||
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||
paths = [r.file_path for r in refs]
|
||||
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
|
||||
|
||||
def test_does_not_return_other_assets_references(self, session: Session):
|
||||
asset1 = _make_asset(session, "hash1")
|
||||
asset2 = _make_asset(session, "hash2")
|
||||
_make_reference(session, asset1, "/path/asset1.bin", name="a1")
|
||||
_make_reference(session, asset2, "/path/asset2.bin", name="a2")
|
||||
session.commit()
|
||||
|
||||
refs = list_references_by_asset_id(session, asset_id=asset1.id)
|
||||
paths = [r.file_path for r in refs]
|
||||
assert paths == ["/path/asset1.bin"]
|
||||
|
||||
|
||||
class TestSelectBestLivePath:
|
||||
def test_returns_empty_for_empty_list(self):
|
||||
result = select_best_live_path([])
|
||||
assert result == ""
|
||||
|
||||
def test_returns_empty_when_no_files_exist(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset, "/nonexistent/path.bin")
|
||||
session.commit()
|
||||
|
||||
result = select_best_live_path([ref])
|
||||
assert result == ""
|
||||
|
||||
def test_prefers_verified_path(self, session: Session, tmp_path):
|
||||
"""needs_verify=False should be preferred."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
verified_file = tmp_path / "verified.bin"
|
||||
verified_file.write_bytes(b"data")
|
||||
|
||||
unverified_file = tmp_path / "unverified.bin"
|
||||
unverified_file.write_bytes(b"data")
|
||||
|
||||
ref_verified = _make_reference(
|
||||
session, asset, str(verified_file), name="verified", needs_verify=False
|
||||
)
|
||||
ref_unverified = _make_reference(
|
||||
session, asset, str(unverified_file), name="unverified", needs_verify=True
|
||||
)
|
||||
session.commit()
|
||||
|
||||
refs = [ref_unverified, ref_verified]
|
||||
result = select_best_live_path(refs)
|
||||
assert result == str(verified_file)
|
||||
|
||||
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
|
||||
"""If all references need verification, return first existing path."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
|
||||
existing_file = tmp_path / "exists.bin"
|
||||
existing_file.write_bytes(b"data")
|
||||
|
||||
ref = _make_reference(session, asset, str(existing_file), needs_verify=True)
|
||||
session.commit()
|
||||
|
||||
result = select_best_live_path([ref])
|
||||
assert result == str(existing_file)
|
||||
|
||||
|
||||
class TestSelectBestLivePathWithMocking:
|
||||
def test_handles_missing_file_path_attr(self):
|
||||
"""Gracefully handle references with None file_path."""
|
||||
|
||||
class MockRef:
|
||||
file_path = None
|
||||
needs_verify = False
|
||||
|
||||
result = select_best_live_path([MockRef()])
|
||||
assert result == ""
|
||||
|
||||
|
||||
class TestUpsertReference:
|
||||
@pytest.mark.parametrize(
|
||||
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
|
||||
[
|
||||
# New reference creation
|
||||
(None, 12345, True, False, 12345),
|
||||
# Existing reference, same mtime - no update
|
||||
(100, 100, False, False, 100),
|
||||
# Existing reference, different mtime - update
|
||||
(100, 200, False, True, 200),
|
||||
],
|
||||
ids=["new_reference", "existing_no_change", "existing_update_mtime"],
|
||||
)
|
||||
def test_upsert_scenarios(
|
||||
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
|
||||
name = f"file_{initial_mtime}_{second_mtime}"
|
||||
|
||||
# Create initial reference if needed
|
||||
if initial_mtime is not None:
|
||||
upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime)
|
||||
session.commit()
|
||||
|
||||
# The upsert call we're testing
|
||||
created, updated = upsert_reference(
|
||||
session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is expect_created
|
||||
assert updated is expect_updated
|
||||
ref = session.query(AssetReference).filter_by(file_path=file_path).one()
|
||||
assert ref.mtime_ns == final_mtime
|
||||
|
||||
def test_upsert_restores_missing_reference(self, session: Session):
|
||||
"""Upserting a reference that was marked missing should restore it."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
file_path = "/restored/file.bin"
|
||||
|
||||
ref = _make_reference(session, asset, file_path, mtime_ns=100)
|
||||
ref.is_missing = True
|
||||
session.commit()
|
||||
|
||||
created, updated = upsert_reference(
|
||||
session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100
|
||||
)
|
||||
session.commit()
|
||||
|
||||
assert created is False
|
||||
assert updated is True
|
||||
restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one()
|
||||
assert restored_ref.is_missing is False
|
||||
|
||||
|
||||
class TestRestoreReferencesByPaths:
|
||||
def test_restores_missing_references(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
missing_path = "/missing/file.bin"
|
||||
active_path = "/active/file.bin"
|
||||
|
||||
missing_ref = _make_reference(session, asset, missing_path, name="missing")
|
||||
missing_ref.is_missing = True
|
||||
_make_reference(session, asset, active_path, name="active")
|
||||
session.commit()
|
||||
|
||||
restored = restore_references_by_paths(session, [missing_path])
|
||||
session.commit()
|
||||
|
||||
assert restored == 1
|
||||
ref = session.query(AssetReference).filter_by(file_path=missing_path).one()
|
||||
assert ref.is_missing is False
|
||||
|
||||
def test_empty_list_restores_nothing(self, session: Session):
|
||||
restored = restore_references_by_paths(session, [])
|
||||
assert restored == 0
|
||||
|
||||
|
||||
class TestMarkReferencesMissingOutsidePrefixes:
|
||||
def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path):
|
||||
asset = _make_asset(session, "hash1")
|
||||
valid_dir = tmp_path / "valid"
|
||||
valid_dir.mkdir()
|
||||
invalid_dir = tmp_path / "invalid"
|
||||
invalid_dir.mkdir()
|
||||
|
||||
valid_path = str(valid_dir / "file.bin")
|
||||
invalid_path = str(invalid_dir / "file.bin")
|
||||
|
||||
_make_reference(session, asset, valid_path, name="valid")
|
||||
_make_reference(session, asset, invalid_path, name="invalid")
|
||||
session.commit()
|
||||
|
||||
marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)])
|
||||
session.commit()
|
||||
|
||||
assert marked == 1
|
||||
all_refs = session.query(AssetReference).all()
|
||||
assert len(all_refs) == 2
|
||||
|
||||
valid_ref = next(r for r in all_refs if r.file_path == valid_path)
|
||||
invalid_ref = next(r for r in all_refs if r.file_path == invalid_path)
|
||||
assert valid_ref.is_missing is False
|
||||
assert invalid_ref.is_missing is True
|
||||
|
||||
def test_empty_prefixes_marks_nothing(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "/some/path.bin")
|
||||
session.commit()
|
||||
|
||||
marked = mark_references_missing_outside_prefixes(session, [])
|
||||
|
||||
assert marked == 0
|
||||
|
||||
|
||||
class TestGetUnreferencedUnhashedAssetIds:
|
||||
def test_returns_unreferenced_unhashed_assets(self, session: Session):
|
||||
# Unhashed asset (hash=None) with no references (no file_path)
|
||||
no_refs = _make_asset(session, hash_val=None)
|
||||
# Unhashed asset with active reference (not unreferenced)
|
||||
with_active_ref = _make_asset(session, hash_val=None)
|
||||
_make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref")
|
||||
# Unhashed asset with only missing reference (should be unreferenced)
|
||||
with_missing_ref = _make_asset(session, hash_val=None)
|
||||
missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref")
|
||||
missing_ref.is_missing = True
|
||||
# Regular asset (hash not None) - should not be returned
|
||||
_make_asset(session, hash_val="blake3:regular")
|
||||
session.commit()
|
||||
|
||||
unreferenced = get_unreferenced_unhashed_asset_ids(session)
|
||||
|
||||
assert no_refs.id in unreferenced
|
||||
assert with_missing_ref.id in unreferenced
|
||||
assert with_active_ref.id not in unreferenced
|
||||
|
||||
|
||||
class TestDeleteAssetsByIds:
|
||||
def test_deletes_assets_and_references(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "/test/path.bin", name="test")
|
||||
session.commit()
|
||||
|
||||
deleted = delete_assets_by_ids(session, [asset.id])
|
||||
session.commit()
|
||||
|
||||
assert deleted == 1
|
||||
assert session.query(Asset).count() == 0
|
||||
assert session.query(AssetReference).count() == 0
|
||||
|
||||
def test_empty_list_deletes_nothing(self, session: Session):
|
||||
_make_asset(session, "hash1")
|
||||
session.commit()
|
||||
|
||||
deleted = delete_assets_by_ids(session, [])
|
||||
|
||||
assert deleted == 0
|
||||
assert session.query(Asset).count() == 1
|
||||
|
||||
|
||||
class TestGetReferencesForPrefixes:
|
||||
def test_returns_references_matching_prefix(self, session: Session, tmp_path):
|
||||
asset = _make_asset(session, "hash1")
|
||||
dir1 = tmp_path / "dir1"
|
||||
dir1.mkdir()
|
||||
dir2 = tmp_path / "dir2"
|
||||
dir2.mkdir()
|
||||
|
||||
path1 = str(dir1 / "file.bin")
|
||||
path2 = str(dir2 / "file.bin")
|
||||
|
||||
_make_reference(session, asset, path1, name="file1", mtime_ns=100)
|
||||
_make_reference(session, asset, path2, name="file2", mtime_ns=200)
|
||||
session.commit()
|
||||
|
||||
rows = get_references_for_prefixes(session, [str(dir1)])
|
||||
|
||||
assert len(rows) == 1
|
||||
assert rows[0].file_path == path1
|
||||
|
||||
def test_empty_prefixes_returns_empty(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "/some/path.bin")
|
||||
session.commit()
|
||||
|
||||
rows = get_references_for_prefixes(session, [])
|
||||
|
||||
assert rows == []
|
||||
|
||||
|
||||
class TestBulkSetNeedsVerify:
|
||||
def test_sets_needs_verify_flag(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False)
|
||||
ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False)
|
||||
session.commit()
|
||||
|
||||
updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True)
|
||||
session.commit()
|
||||
|
||||
assert updated == 2
|
||||
session.refresh(ref1)
|
||||
session.refresh(ref2)
|
||||
assert ref1.needs_verify is True
|
||||
assert ref2.needs_verify is True
|
||||
|
||||
def test_empty_list_updates_nothing(self, session: Session):
|
||||
updated = bulk_update_needs_verify(session, [], True)
|
||||
assert updated == 0
|
||||
|
||||
|
||||
class TestDeleteReferencesByIds:
|
||||
def test_deletes_references_by_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref1 = _make_reference(session, asset, "/path1.bin")
|
||||
_make_reference(session, asset, "/path2.bin")
|
||||
session.commit()
|
||||
|
||||
deleted = delete_references_by_ids(session, [ref1.id])
|
||||
session.commit()
|
||||
|
||||
assert deleted == 1
|
||||
assert session.query(AssetReference).count() == 1
|
||||
|
||||
def test_empty_list_deletes_nothing(self, session: Session):
|
||||
deleted = delete_references_by_ids(session, [])
|
||||
assert deleted == 0
|
||||
|
||||
|
||||
class TestDeleteOrphanedSeedAsset:
|
||||
@pytest.mark.parametrize(
|
||||
"create_asset,expected_deleted,expected_count",
|
||||
[
|
||||
(True, True, 0), # Existing asset gets deleted
|
||||
(False, False, 0), # Nonexistent returns False
|
||||
],
|
||||
ids=["deletes_existing", "nonexistent_returns_false"],
|
||||
)
|
||||
def test_delete_orphaned_seed_asset(
|
||||
self, session: Session, create_asset, expected_deleted, expected_count
|
||||
):
|
||||
asset_id = "nonexistent-id"
|
||||
if create_asset:
|
||||
asset = _make_asset(session, hash_val=None)
|
||||
asset_id = asset.id
|
||||
_make_reference(session, asset, "/test/path.bin", name="test")
|
||||
session.commit()
|
||||
|
||||
deleted = delete_orphaned_seed_asset(session, asset_id)
|
||||
if create_asset:
|
||||
session.commit()
|
||||
|
||||
assert deleted is expected_deleted
|
||||
assert session.query(Asset).count() == expected_count
|
||||
|
||||
|
||||
class TestBulkInsertReferencesIgnoreConflicts:
|
||||
def test_inserts_multiple_references(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{
|
||||
"asset_id": asset.id,
|
||||
"file_path": "/bulk1.bin",
|
||||
"name": "bulk1",
|
||||
"mtime_ns": 100,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
{
|
||||
"asset_id": asset.id,
|
||||
"file_path": "/bulk2.bin",
|
||||
"name": "bulk2",
|
||||
"mtime_ns": 200,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
]
|
||||
bulk_insert_references_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
assert session.query(AssetReference).count() == 2
|
||||
|
||||
def test_ignores_conflicts(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "/existing.bin", mtime_ns=100)
|
||||
session.commit()
|
||||
|
||||
now = get_utc_now()
|
||||
rows = [
|
||||
{
|
||||
"asset_id": asset.id,
|
||||
"file_path": "/existing.bin",
|
||||
"name": "existing",
|
||||
"mtime_ns": 999,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
{
|
||||
"asset_id": asset.id,
|
||||
"file_path": "/new.bin",
|
||||
"name": "new",
|
||||
"mtime_ns": 200,
|
||||
"created_at": now,
|
||||
"updated_at": now,
|
||||
"last_access_time": now,
|
||||
},
|
||||
]
|
||||
bulk_insert_references_ignore_conflicts(session, rows)
|
||||
session.commit()
|
||||
|
||||
assert session.query(AssetReference).count() == 2
|
||||
existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one()
|
||||
assert existing.mtime_ns == 100 # Original value preserved
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
bulk_insert_references_ignore_conflicts(session, [])
|
||||
assert session.query(AssetReference).count() == 0
|
||||
|
||||
|
||||
class TestGetReferencesByPathsAndAssetIds:
|
||||
def test_returns_matching_paths(self, session: Session):
|
||||
asset1 = _make_asset(session, "hash1")
|
||||
asset2 = _make_asset(session, "hash2")
|
||||
|
||||
_make_reference(session, asset1, "/path1.bin")
|
||||
_make_reference(session, asset2, "/path2.bin")
|
||||
session.commit()
|
||||
|
||||
path_to_asset = {
|
||||
"/path1.bin": asset1.id,
|
||||
"/path2.bin": asset2.id,
|
||||
}
|
||||
winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
|
||||
|
||||
assert winners == {"/path1.bin", "/path2.bin"}
|
||||
|
||||
def test_excludes_non_matching_asset_ids(self, session: Session):
|
||||
asset1 = _make_asset(session, "hash1")
|
||||
asset2 = _make_asset(session, "hash2")
|
||||
|
||||
_make_reference(session, asset1, "/path1.bin")
|
||||
session.commit()
|
||||
|
||||
# Path exists but with different asset_id
|
||||
path_to_asset = {"/path1.bin": asset2.id}
|
||||
winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
|
||||
|
||||
assert winners == set()
|
||||
|
||||
def test_empty_dict_returns_empty(self, session: Session):
|
||||
winners = get_references_by_paths_and_asset_ids(session, {})
|
||||
assert winners == set()
|
||||
231
tests-unit/assets_test/queries/test_metadata.py
Normal file
231
tests-unit/assets_test/queries/test_metadata.py
Normal file
@ -0,0 +1,231 @@
|
||||
"""Tests for metadata filtering logic in asset_reference queries."""
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
|
||||
from app.assets.database.queries import list_references_page
|
||||
from app.assets.database.queries.asset_reference import convert_metadata_to_rows
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str) -> Asset:
|
||||
asset = Asset(hash=hash_val, size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def _make_reference(
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
name: str,
|
||||
metadata: dict | None = None,
|
||||
system_metadata: dict | None = None,
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
owner_id="",
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
user_metadata=metadata,
|
||||
system_metadata=system_metadata,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
|
||||
# Build merged projection: {**system_metadata, **user_metadata}
|
||||
merged = {**(system_metadata or {}), **(metadata or {})}
|
||||
if merged:
|
||||
for key, val in merged.items():
|
||||
for row in convert_metadata_to_rows(key, val):
|
||||
meta_row = AssetReferenceMeta(
|
||||
asset_reference_id=ref.id,
|
||||
key=row["key"],
|
||||
ordinal=row.get("ordinal", 0),
|
||||
val_str=row.get("val_str"),
|
||||
val_num=row.get("val_num"),
|
||||
val_bool=row.get("val_bool"),
|
||||
val_json=row.get("val_json"),
|
||||
)
|
||||
session.add(meta_row)
|
||||
session.flush()
|
||||
|
||||
return ref
|
||||
|
||||
|
||||
class TestMetadataFilterByType:
|
||||
"""Table-driven tests for metadata filtering by different value types."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_meta,nomatch_meta,filter_key,filter_val",
|
||||
[
|
||||
# String matching
|
||||
({"category": "models"}, {"category": "images"}, "category", "models"),
|
||||
# Integer matching
|
||||
({"epoch": 5}, {"epoch": 10}, "epoch", 5),
|
||||
# Float matching
|
||||
({"score": 0.95}, {"score": 0.5}, "score", 0.95),
|
||||
# Boolean True matching
|
||||
({"enabled": True}, {"enabled": False}, "enabled", True),
|
||||
# Boolean False matching
|
||||
({"enabled": False}, {"enabled": True}, "enabled", False),
|
||||
],
|
||||
ids=["string", "int", "float", "bool_true", "bool_false"],
|
||||
)
|
||||
def test_filter_matches_correct_value(
|
||||
self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "match", match_meta)
|
||||
_make_reference(session, asset, "nomatch", nomatch_meta)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={filter_key: filter_val}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "match"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"stored_meta,filter_key,filter_val",
|
||||
[
|
||||
# String no match
|
||||
({"category": "models"}, "category", "other"),
|
||||
# Int no match
|
||||
({"epoch": 5}, "epoch", 99),
|
||||
# Float no match
|
||||
({"score": 0.5}, "score", 0.99),
|
||||
],
|
||||
ids=["string_no_match", "int_no_match", "float_no_match"],
|
||||
)
|
||||
def test_filter_returns_empty_when_no_match(
|
||||
self, session: Session, stored_meta, filter_key, filter_val
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "item", stored_meta)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={filter_key: filter_val}
|
||||
)
|
||||
assert total == 0
|
||||
|
||||
|
||||
class TestMetadataFilterNull:
|
||||
"""Tests for null/missing key filtering."""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
|
||||
[
|
||||
# Null matches missing key
|
||||
("missing_key", {}, "has_key", {"optional": "value"}, "optional"),
|
||||
# Null matches explicit null
|
||||
("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"),
|
||||
],
|
||||
ids=["missing_key", "explicit_null"],
|
||||
)
|
||||
def test_null_filter_matches(
|
||||
self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
|
||||
):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, match_name, match_meta)
|
||||
_make_reference(session, asset, nomatch_name, nomatch_meta)
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(session, metadata_filter={filter_key: None})
|
||||
assert total == 1
|
||||
assert refs[0].name == match_name
|
||||
|
||||
|
||||
class TestMetadataFilterList:
|
||||
"""Tests for list-based (OR) filtering."""
|
||||
|
||||
def test_filter_by_list_matches_any(self, session: Session):
|
||||
"""List values should match ANY of the values (OR)."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "cat_a", {"category": "a"})
|
||||
_make_reference(session, asset, "cat_b", {"category": "b"})
|
||||
_make_reference(session, asset, "cat_c", {"category": "c"})
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(session, metadata_filter={"category": ["a", "b"]})
|
||||
assert total == 2
|
||||
names = {r.name for r in refs}
|
||||
assert names == {"cat_a", "cat_b"}
|
||||
|
||||
|
||||
class TestMetadataFilterMultipleKeys:
|
||||
"""Tests for multiple filter keys (AND semantics)."""
|
||||
|
||||
def test_multiple_keys_must_all_match(self, session: Session):
|
||||
"""Multiple keys should ALL match (AND)."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "match", {"type": "model", "version": 2})
|
||||
_make_reference(session, asset, "wrong_type", {"type": "config", "version": 2})
|
||||
_make_reference(session, asset, "wrong_version", {"type": "model", "version": 1})
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"type": "model", "version": 2}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "match"
|
||||
|
||||
|
||||
class TestMetadataFilterEmptyDict:
|
||||
"""Tests for empty filter behavior."""
|
||||
|
||||
def test_empty_filter_returns_all(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(session, asset, "a", {"key": "val"})
|
||||
_make_reference(session, asset, "b", {})
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(session, metadata_filter={})
|
||||
assert total == 2
|
||||
|
||||
|
||||
class TestSystemMetadataProjection:
|
||||
"""Tests for system_metadata merging into the filter projection."""
|
||||
|
||||
def test_system_metadata_keys_are_filterable(self, session: Session):
|
||||
"""system_metadata keys should appear in the merged projection."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(
|
||||
session, asset, "with_sys",
|
||||
system_metadata={"source": "scanner"},
|
||||
)
|
||||
_make_reference(session, asset, "without_sys")
|
||||
session.commit()
|
||||
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"source": "scanner"}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "with_sys"
|
||||
|
||||
def test_user_metadata_overrides_system_metadata(self, session: Session):
|
||||
"""user_metadata should win when both have the same key."""
|
||||
asset = _make_asset(session, "hash1")
|
||||
_make_reference(
|
||||
session, asset, "overridden",
|
||||
metadata={"origin": "user_upload"},
|
||||
system_metadata={"origin": "auto_scan"},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Should match the user value, not the system value
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"origin": "user_upload"}
|
||||
)
|
||||
assert total == 1
|
||||
assert refs[0].name == "overridden"
|
||||
|
||||
# Should NOT match the system value (it was overridden)
|
||||
refs, _, total = list_references_page(
|
||||
session, metadata_filter={"origin": "auto_scan"}
|
||||
)
|
||||
assert total == 0
|
||||
366
tests-unit/assets_test/queries/test_tags.py
Normal file
366
tests-unit/assets_test/queries/test_tags.py
Normal file
@ -0,0 +1,366 @@
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag
|
||||
from app.assets.database.queries import (
|
||||
ensure_tags_exist,
|
||||
get_reference_tags,
|
||||
set_reference_tags,
|
||||
add_tags_to_reference,
|
||||
remove_tags_from_reference,
|
||||
add_missing_tag_for_asset_id,
|
||||
remove_missing_tag_for_asset_id,
|
||||
list_tags_with_usage,
|
||||
bulk_insert_tags_and_meta,
|
||||
)
|
||||
from app.assets.helpers import get_utc_now
|
||||
|
||||
|
||||
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
|
||||
asset = Asset(hash=hash_val, size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
return asset
|
||||
|
||||
|
||||
def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
owner_id=owner_id,
|
||||
name=name,
|
||||
asset_id=asset.id,
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
|
||||
class TestEnsureTagsExist:
|
||||
def test_creates_new_tags(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_is_idempotent(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
session.commit()
|
||||
|
||||
assert session.query(Tag).count() == 1
|
||||
|
||||
def test_normalizes_tags(self, session: Session):
|
||||
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_empty_list_is_noop(self, session: Session):
|
||||
ensure_tags_exist(session, [])
|
||||
session.commit()
|
||||
assert session.query(Tag).count() == 0
|
||||
|
||||
def test_tag_type_is_set(self, session: Session):
|
||||
ensure_tags_exist(session, ["system-tag"], tag_type="system")
|
||||
session.commit()
|
||||
|
||||
tag = session.query(Tag).filter_by(name="system-tag").one()
|
||||
assert tag.tag_type == "system"
|
||||
|
||||
|
||||
class TestGetReferenceTags:
|
||||
def test_returns_empty_for_no_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
tags = get_reference_tags(session, reference_id=ref.id)
|
||||
assert tags == []
|
||||
|
||||
def test_returns_tags_for_reference(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
ensure_tags_exist(session, ["tag1", "tag2"])
|
||||
session.add_all([
|
||||
AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag1", origin="manual", added_at=get_utc_now()),
|
||||
AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag2", origin="manual", added_at=get_utc_now()),
|
||||
])
|
||||
session.flush()
|
||||
|
||||
tags = get_reference_tags(session, reference_id=ref.id)
|
||||
assert set(tags) == {"tag1", "tag2"}
|
||||
|
||||
|
||||
class TestSetReferenceTags:
|
||||
def test_adds_new_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"])
|
||||
session.commit()
|
||||
|
||||
assert set(result.added) == {"a", "b"}
|
||||
assert result.removed == []
|
||||
assert set(result.total) == {"a", "b"}
|
||||
|
||||
def test_removes_old_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["a", "b", "c"])
|
||||
result = set_reference_tags(session, reference_id=ref.id, tags=["a"])
|
||||
session.commit()
|
||||
|
||||
assert result.added == []
|
||||
assert set(result.removed) == {"b", "c"}
|
||||
assert result.total == ["a"]
|
||||
|
||||
def test_replaces_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
set_reference_tags(session, reference_id=ref.id, tags=["a", "b"])
|
||||
result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"])
|
||||
session.commit()
|
||||
|
||||
assert result.added == ["c"]
|
||||
assert result.removed == ["a"]
|
||||
assert set(result.total) == {"b", "c"}
|
||||
|
||||
|
||||
class TestAddTagsToReference:
|
||||
def test_adds_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
|
||||
session.commit()
|
||||
|
||||
assert set(result.added) == {"x", "y"}
|
||||
assert result.already_present == []
|
||||
|
||||
def test_reports_already_present(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["x"])
|
||||
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
|
||||
session.commit()
|
||||
|
||||
assert result.added == ["y"]
|
||||
assert result.already_present == ["x"]
|
||||
|
||||
def test_raises_for_missing_reference(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestRemoveTagsFromReference:
|
||||
def test_removes_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"])
|
||||
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"])
|
||||
session.commit()
|
||||
|
||||
assert set(result.removed) == {"a", "b"}
|
||||
assert result.not_present == []
|
||||
assert result.total_tags == ["c"]
|
||||
|
||||
def test_reports_not_present(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["a"])
|
||||
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"])
|
||||
session.commit()
|
||||
|
||||
assert result.removed == ["a"]
|
||||
assert result.not_present == ["x"]
|
||||
|
||||
def test_raises_for_missing_reference(self, session: Session):
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"])
|
||||
|
||||
|
||||
class TestMissingTagFunctions:
|
||||
def test_add_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
|
||||
tags = get_reference_tags(session, reference_id=ref.id)
|
||||
assert "missing" in tags
|
||||
|
||||
def test_add_missing_tag_is_idempotent(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
|
||||
links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="missing").all()
|
||||
assert len(links) == 1
|
||||
|
||||
def test_remove_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
|
||||
tags = get_reference_tags(session, reference_id=ref.id)
|
||||
assert "missing" not in tags
|
||||
|
||||
|
||||
class TestListTagsWithUsage:
|
||||
def test_returns_tags_with_counts(self, session: Session):
|
||||
ensure_tags_exist(session, ["used", "unused"])
|
||||
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["used"])
|
||||
session.commit()
|
||||
|
||||
rows, total = list_tags_with_usage(session)
|
||||
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
assert tag_dict["used"] == 1
|
||||
assert tag_dict["unused"] == 0
|
||||
assert total == 2
|
||||
|
||||
def test_exclude_zero_counts(self, session: Session):
|
||||
ensure_tags_exist(session, ["used", "unused"])
|
||||
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["used"])
|
||||
session.commit()
|
||||
|
||||
rows, total = list_tags_with_usage(session, include_zero=False)
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
assert "used" in tag_names
|
||||
assert "unused" not in tag_names
|
||||
|
||||
def test_prefix_filter(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
|
||||
session.commit()
|
||||
|
||||
rows, total = list_tags_with_usage(session, prefix="alph")
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
assert tag_names == {"alpha", "alphabet"}
|
||||
|
||||
def test_order_by_name(self, session: Session):
|
||||
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
|
||||
session.commit()
|
||||
|
||||
rows, _ = list_tags_with_usage(session, order="name_asc")
|
||||
|
||||
names = [name for name, _, _ in rows]
|
||||
assert names == ["alpha", "middle", "zebra"]
|
||||
|
||||
def test_owner_visibility(self, session: Session):
|
||||
ensure_tags_exist(session, ["shared-tag", "owner-tag"])
|
||||
|
||||
asset = _make_asset(session, "hash1")
|
||||
shared_ref = _make_reference(session, asset, name="shared", owner_id="")
|
||||
owner_ref = _make_reference(session, asset, name="owned", owner_id="user1")
|
||||
|
||||
add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"])
|
||||
add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"])
|
||||
session.commit()
|
||||
|
||||
# Empty owner sees only shared
|
||||
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
assert tag_dict.get("shared-tag", 0) == 1
|
||||
assert tag_dict.get("owner-tag", 0) == 0
|
||||
|
||||
# User1 sees both
|
||||
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
assert tag_dict.get("shared-tag", 0) == 1
|
||||
assert tag_dict.get("owner-tag", 0) == 1
|
||||
|
||||
|
||||
class TestBulkInsertTagsAndMeta:
|
||||
def test_inserts_tags(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
|
||||
session.commit()
|
||||
|
||||
now = get_utc_now()
|
||||
tag_rows = [
|
||||
{"asset_reference_id": ref.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now},
|
||||
{"asset_reference_id": ref.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now},
|
||||
]
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
|
||||
session.commit()
|
||||
|
||||
tags = get_reference_tags(session, reference_id=ref.id)
|
||||
assert set(tags) == {"bulk-tag1", "bulk-tag2"}
|
||||
|
||||
def test_inserts_meta(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
session.commit()
|
||||
|
||||
meta_rows = [
|
||||
{
|
||||
"asset_reference_id": ref.id,
|
||||
"key": "meta-key",
|
||||
"ordinal": 0,
|
||||
"val_str": "meta-value",
|
||||
"val_num": None,
|
||||
"val_bool": None,
|
||||
"val_json": None,
|
||||
},
|
||||
]
|
||||
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows)
|
||||
session.commit()
|
||||
|
||||
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
|
||||
assert len(meta) == 1
|
||||
assert meta[0].key == "meta-key"
|
||||
assert meta[0].val_str == "meta-value"
|
||||
|
||||
def test_ignores_conflicts(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["existing-tag"])
|
||||
add_tags_to_reference(session, reference_id=ref.id, tags=["existing-tag"])
|
||||
session.commit()
|
||||
|
||||
now = get_utc_now()
|
||||
tag_rows = [
|
||||
{"asset_reference_id": ref.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now},
|
||||
]
|
||||
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
|
||||
session.commit()
|
||||
|
||||
# Should still have only one tag link
|
||||
links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="existing-tag").all()
|
||||
assert len(links) == 1
|
||||
# Origin should be original, not overwritten
|
||||
assert links[0].origin == "manual"
|
||||
|
||||
def test_empty_lists_is_noop(self, session: Session):
|
||||
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])
|
||||
assert session.query(AssetReferenceTag).count() == 0
|
||||
assert session.query(AssetReferenceMeta).count() == 0
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user