Merge branch 'master' into mps-text-encoder-device

This commit is contained in:
comfyanonymous 2026-03-16 19:32:09 -07:00 committed by GitHub
commit 46d76508c9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
122 changed files with 14283 additions and 3217 deletions

103
.github/scripts/check-ai-co-authors.sh vendored Executable file
View File

@ -0,0 +1,103 @@
#!/usr/bin/env bash
# Checks pull request commits for AI agent Co-authored-by trailers.
# Exits non-zero when any are found and prints fix instructions.
set -euo pipefail
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
# Known AI coding-agent trailer patterns (case-insensitive).
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
AGENT_PATTERNS=(
# Anthropic — Claude Code / Amp
'noreply@anthropic\.com'
# Cursor
'cursoragent@cursor\.com'
# GitHub Copilot
'copilot-swe-agent\[bot\]'
'copilot@github\.com'
# OpenAI Codex
'noreply@openai\.com'
'codex@openai\.com'
# Aider
'aider@aider\.chat'
# Google — Gemini / Jules
'gemini@google\.com'
'jules@google\.com'
# Windsurf / Codeium
'@codeium\.com'
# Devin
'devin-ai-integration\[bot\]'
'devin@cognition\.ai'
'devin@cognition-labs\.com'
# Amazon Q Developer
'amazon-q-developer'
'@amazon\.com.*[Qq].[Dd]eveloper'
# Cline
'cline-bot'
'cline@cline\.ai'
# Continue
'continue-agent'
'continue@continue\.dev'
# Sourcegraph
'noreply@sourcegraph\.com'
# Generic catch-alls for common agent name patterns
'Co-authored-by:.*\b[Cc]laude\b'
'Co-authored-by:.*\b[Cc]opilot\b'
'Co-authored-by:.*\b[Cc]ursor\b'
'Co-authored-by:.*\b[Cc]odex\b'
'Co-authored-by:.*\b[Gg]emini\b'
'Co-authored-by:.*\b[Aa]ider\b'
'Co-authored-by:.*\b[Dd]evin\b'
'Co-authored-by:.*\b[Ww]indsurf\b'
'Co-authored-by:.*\b[Cc]line\b'
'Co-authored-by:.*\b[Aa]mazon Q\b'
'Co-authored-by:.*\b[Jj]ules\b'
'Co-authored-by:.*\bOpenCode\b'
)
# Build a single alternation regex from all patterns.
regex=""
for pattern in "${AGENT_PATTERNS[@]}"; do
if [[ -n "$regex" ]]; then
regex="${regex}|${pattern}"
else
regex="$pattern"
fi
done
# Collect Co-authored-by lines from every commit in the PR range.
violations=""
while IFS= read -r sha; do
message="$(git log -1 --format='%B' "$sha")"
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
if [[ -z "$matched_lines" ]]; then
continue
fi
while IFS= read -r line; do
if echo "$line" | grep -iqE "$regex"; then
short="$(git log -1 --format='%h' "$sha")"
violations="${violations} ${short}: ${line}"$'\n'
fi
done <<< "$matched_lines"
done < <(git rev-list "${base_sha}..${head_sha}")
if [[ -n "$violations" ]]; then
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
echo ""
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
echo ""
echo "$violations"
echo "These trailers should be removed before merging."
echo ""
echo "To fix, rewrite the commit messages with:"
echo " git rebase -i ${base_sha}"
echo ""
echo "and remove the Co-authored-by lines, then force-push your branch."
echo ""
echo "If you believe this is a false positive, please open an issue."
exit 1
fi
echo "No AI agent Co-authored-by trailers found."

View File

@ -0,0 +1,19 @@
name: Check AI Co-Authors
on:
pull_request:
branches: ['*']
jobs:
check-ai-co-authors:
name: Check for AI agent co-author trailers
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check commits for AI co-author trailers
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"

View File

@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started ## Get Started
### Local
#### [Desktop Application](https://www.comfy.org/download) #### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started. - The easiest way to get started.
- Available on Windows & macOS. - Available on Windows & macOS.
@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux) #### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend). Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/) ### Cloud
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version for those who can't afford local hardware.
## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.

View File

@ -8,7 +8,7 @@ from alembic import context
config = context.config config = context.config
from app.database.models import Base from app.database.models import Base, NAMING_CONVENTION
target_metadata = Base.metadata target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(
connection=connection, target_metadata=target_metadata connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
naming_convention=NAMING_CONVENTION,
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@ -0,0 +1,267 @@
"""
Merge AssetInfo and AssetCacheState into unified asset_references table.
This migration drops old tables and creates the new unified schema.
All existing data is discarded.
Revision ID: 0002_merge_to_asset_references
Revises: 0001_assets
Create Date: 2025-02-11
"""
from alembic import op
import sqlalchemy as sa
revision = "0002_merge_to_asset_references"
down_revision = "0001_assets"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop old tables (order matters due to FK constraints)
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
# Truncate assets table (cascades handled by dropping dependent tables first)
op.execute("DELETE FROM assets")
# Create asset_references table
op.create_table(
"asset_references",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column(
"asset_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column(
"needs_verify",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
sa.Column(
"is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column(
"preview_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True),
sa.CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
sa.CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
)
op.create_index(
"uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
)
op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
op.create_index("ix_asset_references_name", "asset_references", ["name"])
op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
op.create_index(
"ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
)
op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
op.create_index(
"ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
)
op.create_index(
"ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
)
op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"])
# Create asset_reference_tags table
op.create_table(
"asset_reference_tags",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"tag_name",
sa.String(length=512),
sa.ForeignKey("tags.name", ondelete="RESTRICT"),
nullable=False,
),
sa.Column(
"origin", sa.String(length=32), nullable=False, server_default="manual"
),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint(
"asset_reference_id", "tag_name", name="pk_asset_reference_tags"
),
)
op.create_index(
"ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
)
op.create_index(
"ix_asset_reference_tags_asset_reference_id",
"asset_reference_tags",
["asset_reference_id"],
)
# Create asset_reference_meta table
op.create_table(
"asset_reference_meta",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint(
"asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
),
)
op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
op.create_index(
"ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
)
op.create_index(
"ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
)
op.create_index(
"ix_asset_reference_meta_key_val_bool",
"asset_reference_meta",
["key", "val_bool"],
)
def downgrade() -> None:
"""Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.
NOTE: Data is not recoverable. The upgrade discards all rows from the old
tables and truncates assets. After downgrade the old schema will be empty.
A filesystem rescan will repopulate data once the older code is running.
"""
# Drop new tables (order matters due to FK constraints)
op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
op.drop_table("asset_reference_meta")
op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
op.drop_table("asset_reference_tags")
op.drop_index("ix_asset_references_deleted_at", table_name="asset_references")
op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
op.drop_index("ix_asset_references_created_at", table_name="asset_references")
op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
op.drop_index("ix_asset_references_name", table_name="asset_references")
op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
op.drop_index("uq_asset_references_file_path", table_name="asset_references")
op.drop_table("asset_references")
# Truncate assets (upgrade deleted all rows; downgrade starts fresh too)
op.execute("DELETE FROM assets")
# Recreate old tables from 0001_assets schema
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

View File

@ -0,0 +1,98 @@
"""
Add system_metadata and job_id columns to asset_references.
Change preview_id FK from assets.id to asset_references.id.
Revision ID: 0003_add_metadata_job_id
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
from app.database.models import NAMING_CONVENTION
revision = "0003_add_metadata_job_id"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("job_id", sa.String(length=36), nullable=True)
)
# Change preview_id FK from assets.id to asset_references.id (self-ref).
# Existing values are asset-content IDs that won't match reference IDs,
# so null them out first.
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_constraint(
"fk_asset_references_preview_id_assets", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_asset_references",
"asset_references",
["preview_id"],
["id"],
ondelete="SET NULL",
)
batch_op.create_index(
"ix_asset_references_preview_id", ["preview_id"]
)
# Purge any all-null meta rows before adding the constraint
op.execute(
"DELETE FROM asset_reference_meta"
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
)
with op.batch_alter_table("asset_reference_meta") as batch_op:
batch_op.create_check_constraint(
"ck_asset_reference_meta_has_value",
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
)
def downgrade() -> None:
# SQLite doesn't reflect CHECK constraints, so we must declare it
# explicitly via table_args for the batch recreate to find it.
# Use the fully-rendered constraint name to avoid the naming convention
# doubling the prefix.
with op.batch_alter_table(
"asset_reference_meta",
table_args=[
sa.CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="ck_asset_reference_meta_has_value",
),
],
) as batch_op:
batch_op.drop_constraint(
"ck_asset_reference_meta_has_value", type_="check"
)
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_index("ix_asset_references_preview_id")
batch_op.drop_constraint(
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_assets",
"assets",
["preview_id"],
["id"],
ondelete="SET NULL",
)
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("job_id")
batch_op.drop_column("system_metadata")

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,8 @@
import json import json
from dataclasses import dataclass
from typing import Any, Literal from typing import Any, Literal
from app.assets.helpers import validate_blake3_hash
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
ConfigDict, ConfigDict,
@ -10,6 +12,43 @@ from pydantic import (
model_validator, model_validator,
) )
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
self.status = status
self.code = code
self.message = message
class AssetValidationError(Exception):
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
def __init__(self, code: str, message: str):
super().__init__(message)
self.code = code
self.message = message
@dataclass
class ParsedUpload:
"""Result of parsing a multipart upload request."""
file_present: bool
file_written: int
file_client_name: str | None
tmp_path: str | None
tags_raw: list[str]
provided_name: str | None
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list) include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list) exclude_tags: list[str] = Field(default_factory=list)
@ -21,7 +60,9 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20 limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0 offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at" sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
)
order: Literal["asc", "desc"] = "desc" order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before") @field_validator("include_tags", "exclude_tags", mode="before")
@ -59,11 +100,17 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel): class UpdateAssetBody(BaseModel):
name: str | None = None name: str | None = None
user_metadata: dict[str, Any] | None = None user_metadata: dict[str, Any] | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@model_validator(mode="after") @model_validator(mode="after")
def _at_least_one(self): def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None: if all(
raise ValueError("Provide at least one of: name, user_metadata.") v is None
for v in (self.name, self.user_metadata, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, preview_id."
)
return self return self
@ -71,26 +118,20 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str hash: str
name: str name: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@field_validator("hash") @field_validator("hash")
@classmethod @classmethod
def _require_blake3(cls, v): def _require_blake3(cls, v):
s = (v or "").strip().lower() return validate_blake3_hash(v or "")
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
@field_validator("tags", mode="before") @field_validator("tags", mode="before")
@classmethod @classmethod
def _tags_norm(cls, v): def _normalize_tags_field(cls, v):
if v is None: if v is None:
return [] return []
if isinstance(v, list): if isinstance(v, list):
@ -107,6 +148,44 @@ class CreateFromHashBody(BaseModel):
return [] return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel): class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@ -154,38 +233,36 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel): class UploadAssetSpec(BaseModel):
"""Upload Asset operation. """Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths - tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category
- name: display name - name: display name
- user_metadata: arbitrary JSON object (optional) - user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path - hash: optional canonical 'blake3:<hex>' for validation / fast-path
- mime_type: optional MIME type override
- preview_id: optional asset_reference ID for preview
Files created via this endpoint are stored on disk using the **content hash** as the filename stem Files are stored using the content hash as filename stem.
and the original extension is preserved when available.
""" """
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1) tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name") name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None) hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None) # references an asset_reference id
@field_validator("hash", mode="before") @field_validator("hash", mode="before")
@classmethod @classmethod
def _parse_hash(cls, v): def _parse_hash(cls, v):
if v is None: if v is None:
return None return None
s = str(v).strip().lower() s = str(v).strip()
if not s: if not s:
return None return None
if ":" not in s: return validate_blake3_hash(s)
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
@field_validator("tags", mode="before") @field_validator("tags", mode="before")
@classmethod @classmethod
@ -254,11 +331,13 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_order(self): def _validate_order(self):
if not self.tags: if not self.tags:
raise ValueError("tags must be provided and non-empty") raise ValueError("at least one tag is required for uploads")
root = self.tags[0] root = self.tags[0]
if root not in {"models", "input", "output"}: if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output") raise ValueError("first tag must be one of: models, input, output")
if root == "models": if root == "models":
if len(self.tags) < 2: if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag") raise ValueError(
"models uploads require a category tag as the second tag"
)
return self return self

View File

@ -4,7 +4,10 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel): class Asset(BaseModel):
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str id: str
name: str name: str
asset_hash: str | None = None asset_hash: str | None = None
@ -12,61 +15,33 @@ class AssetSummary(BaseModel):
mime_type: str | None = None mime_type: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None
created_at: datetime | None = None preview_id: str | None = None # references an asset_reference id, not an asset id
updated_at: datetime | None = None user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
job_id: str | None = None
prompt_id: str | None = None # deprecated: use job_id
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time") @field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info): def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel): class AssetsList(BaseModel):
assets: list[AssetSummary] assets: list[Asset]
total: int total: int
has_more: bool has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel): class TagUsage(BaseModel):
name: str name: str
count: int count: int
@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list) removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

185
app/assets/api/upload.py Normal file
View File

@ -0,0 +1,185 @@
import logging
import os
import uuid
from typing import Callable
from aiohttp import web
import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
from app.assets.helpers import validate_blake3_hash
def normalize_and_validate_hash(s: str) -> str:
"""Validate and normalize a hash string.
Returns canonical 'blake3:<hex>' or raises UploadError.
"""
try:
return validate_blake3_hash(s)
except ValueError:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
async def parse_multipart_upload(
request: web.Request,
check_hash_exists: Callable[[str], bool],
) -> ParsedUpload:
"""
Parse a multipart/form-data upload request.
Args:
request: The aiohttp request
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
Returns:
ParsedUpload with parsed fields and temp file path
Raises:
UploadError: On validation or I/O errors
"""
if not (request.content_type or "").lower().startswith("multipart/"):
raise UploadError(
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
)
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
raise UploadError(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
if s:
provided_hash = normalize_and_validate_hash(s)
try:
provided_hash_exists = check_hash_exists(provided_hash)
except Exception as e:
logging.exception(
"check_hash_exists failed for hash=%s: %s", provided_hash, e
)
raise UploadError(
500,
"HASH_CHECK_FAILED",
"Backend error while checking asset hash.",
)
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# Hash exists - drain file but don't write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
)
continue
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
delete_temp_file_if_exists(tmp_path)
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
)
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
elif fname == "id":
raise UploadError(
400,
"UNSUPPORTED_FIELD",
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
)
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
)
if (
file_present
and file_written == 0
and not (provided_hash and provided_hash_exists)
):
delete_temp_file_if_exists(tmp_path)
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
return ParsedUpload(
file_present=file_present,
file_written=file_written,
file_client_name=file_client_name,
tmp_path=tmp_path,
tags_raw=tags_raw,
provided_name=provided_name,
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
)
def delete_temp_file_if_exists(tmp_path: str | None) -> None:
"""Safely remove a temp file and its parent directory if empty."""
if tmp_path:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except OSError as e:
logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
try:
parent = os.path.dirname(tmp_path)
if parent and os.path.isdir(parent):
os.rmdir(parent) # only succeeds if empty
except OSError:
pass

View File

@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@ -2,8 +2,8 @@ from __future__ import annotations
import uuid import uuid
from datetime import datetime from datetime import datetime
from typing import Any from typing import Any
from sqlalchemy import ( from sqlalchemy import (
JSON, JSON,
BigInteger, BigInteger,
@ -16,47 +16,36 @@ from sqlalchemy import (
Numeric, Numeric,
String, String,
Text, Text,
UniqueConstraint,
) )
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow from app.assets.helpers import get_utc_now
from app.database.models import to_dict, Base from app.database.models import Base
class Asset(Base): class Asset(Base):
__tablename__ = "assets" __tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4())) id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
hash: Mapped[str | None] = mapped_column(String(256), nullable=True) hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0) size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255)) mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow DateTime(timezone=False), nullable=False, default=get_utc_now
) )
infos: Mapped[list[AssetInfo]] = relationship( references: Mapped[list[AssetReference]] = relationship(
"AssetInfo", "AssetReference",
back_populates="asset", back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id), primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id], foreign_keys=lambda: [AssetReference.asset_id],
cascade="all,delete-orphan", cascade="all,delete-orphan",
passive_deletes=True, passive_deletes=True,
) )
preview_of: Mapped[list[AssetInfo]] = relationship( # preview_id on AssetReference is a self-referential FK to asset_references.id
"AssetInfo",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = ( __table_args__ = (
Index("uq_assets_hash", "hash", unique=True), Index("uq_assets_hash", "hash", unique=True),
@ -64,108 +53,126 @@ class Asset(Base):
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"), CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
) )
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>" return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base): class AssetReference(Base):
__tablename__ = "asset_cache_state" """Unified model combining file cache state and user-facing metadata.
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True) Each row represents either:
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False) - A filesystem reference (file_path is set) with cache state
file_path: Mapped[str] = mapped_column(Text, nullable=False) - An API-created reference (file_path is NULL) without cache state
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True) """
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
asset: Mapped[Asset] = relationship(back_populates="cache_states") __tablename__ = "asset_references"
__table_args__ = ( id: Mapped[str] = mapped_column(
Index("ix_asset_cache_state_file_path", "file_path"), String(36), primary_key=True, default=lambda: str(uuid.uuid4())
Index("ix_asset_cache_state_asset_id", "asset_id"), )
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"), asset_id: Mapped[str] = mapped_column(
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"), String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
) )
def to_dict(self, include_none: bool = False) -> dict[str, Any]: # Cache state fields (from former AssetCacheState)
return to_dict(self, include_none=include_none) file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
def __repr__(self) -> str: # Info fields (from former AssetInfo)
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False) name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False) preview_id: Mapped[str | None] = mapped_column(
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL")) String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True)) )
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) JSON(none_as_null=True)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow) )
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
deleted_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=False), nullable=True, default=None
)
asset: Mapped[Asset] = relationship( asset: Mapped[Asset] = relationship(
"Asset", "Asset",
back_populates="infos", back_populates="references",
foreign_keys=[asset_id], foreign_keys=[asset_id],
lazy="selectin", lazy="selectin",
) )
preview_asset: Mapped[Asset | None] = relationship( preview_ref: Mapped[AssetReference | None] = relationship(
"Asset", "AssetReference",
back_populates="preview_of",
foreign_keys=[preview_id], foreign_keys=[preview_id],
remote_side=lambda: [AssetReference.id],
) )
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship( metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
back_populates="asset_info", back_populates="asset_reference",
cascade="all,delete-orphan", cascade="all,delete-orphan",
passive_deletes=True, passive_deletes=True,
) )
tag_links: Mapped[list[AssetInfoTag]] = relationship( tag_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="asset_info", back_populates="asset_reference",
cascade="all,delete-orphan", cascade="all,delete-orphan",
passive_deletes=True, passive_deletes=True,
overlaps="tags,asset_infos", overlaps="tags,asset_references",
) )
tags: Mapped[list[Tag]] = relationship( tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags", secondary="asset_reference_tags",
back_populates="asset_infos", back_populates="asset_references",
lazy="selectin", lazy="selectin",
viewonly=True, viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag", overlaps="tag_links,asset_reference_links,asset_references,tag",
) )
__table_args__ = ( __table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"), Index("uq_asset_references_file_path", "file_path", unique=True),
Index("ix_assets_info_owner_name", "owner_id", "name"), Index("ix_asset_references_asset_id", "asset_id"),
Index("ix_assets_info_owner_id", "owner_id"), Index("ix_asset_references_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"), Index("ix_asset_references_name", "name"),
Index("ix_assets_info_name", "name"), Index("ix_asset_references_is_missing", "is_missing"),
Index("ix_assets_info_created_at", "created_at"), Index("ix_asset_references_enrichment_level", "enrichment_level"),
Index("ix_assets_info_last_access_time", "last_access_time"), Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"),
Index("ix_asset_references_preview_id", "preview_id"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
) )
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>" path_part = f" path={self.file_path!r}" if self.file_path else ""
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
class AssetInfoMeta(Base): class AssetReferenceMeta(Base):
__tablename__ = "asset_info_meta" __tablename__ = "asset_reference_meta"
asset_info_id: Mapped[str] = mapped_column( asset_reference_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
) )
key: Mapped[str] = mapped_column(String(256), primary_key=True) key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0) ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
@ -175,36 +182,44 @@ class AssetInfoMeta(Base):
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True) val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True) val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries") asset_reference: Mapped[AssetReference] = relationship(
back_populates="metadata_entries"
)
__table_args__ = ( __table_args__ = (
Index("ix_asset_info_meta_key", "key"), Index("ix_asset_reference_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"), Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"), Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"), Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="has_value",
),
) )
class AssetInfoTag(Base): class AssetReferenceTag(Base):
__tablename__ = "asset_info_tags" __tablename__ = "asset_reference_tags"
asset_info_id: Mapped[str] = mapped_column( asset_reference_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
) )
tag_name: Mapped[str] = mapped_column( tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
) )
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual") origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column( added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow DateTime(timezone=False), nullable=False, default=get_utc_now
) )
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links") asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links") tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
__table_args__ = ( __table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"), Index("ix_asset_reference_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"), Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
) )
@ -214,20 +229,18 @@ class Tag(Base):
name: Mapped[str] = mapped_column(String(512), primary_key=True) name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user") tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship( asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="tag", back_populates="tag",
overlaps="asset_infos,tags", overlaps="asset_references,tags",
) )
asset_infos: Mapped[list[AssetInfo]] = relationship( asset_references: Mapped[list[AssetReference]] = relationship(
secondary="asset_info_tags", secondary="asset_reference_tags",
back_populates="tags", back_populates="tags",
viewonly=True, viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info", overlaps="asset_reference_links,tag_links,tags,asset_reference",
) )
__table_args__ = ( __table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
Index("ix_tags_tag_type", "tag_type"),
)
def __repr__(self) -> str: def __repr__(self) -> str:
return f"<Tag {self.name}>" return f"<Tag {self.name}>"

View File

@ -1,976 +0,0 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@ -0,0 +1,133 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
update_asset_hash_and_mime,
upsert_asset,
)
from app.assets.database.queries.asset_reference import (
CacheStateRow,
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
delete_assets_by_ids,
delete_orphaned_seed_asset,
delete_reference_by_id,
delete_references_by_ids,
fetch_reference_and_asset,
fetch_reference_asset_and_tags,
get_or_create_reference,
get_reference_by_file_path,
get_reference_by_id,
get_reference_with_owner_check,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_references_for_prefixes,
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
insert_reference,
list_all_file_paths_by_asset_id,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
rebuild_metadata_projection,
reference_exists,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
set_reference_system_metadata,
soft_delete_reference_by_id,
update_reference_access_time,
update_reference_name,
update_is_missing_by_asset_id,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
)
from app.assets.database.queries.tags import (
AddTagsResult,
RemoveTagsResult,
SetTagsResult,
add_missing_tag_for_asset_id,
add_tags_to_reference,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
set_reference_tags,
validate_tags_exist,
)
__all__ = [
"AddTagsResult",
"CacheStateRow",
"RemoveTagsResult",
"SetTagsResult",
"UnenrichedReferenceRow",
"add_missing_tag_for_asset_id",
"add_tags_to_reference",
"asset_exists_by_hash",
"bulk_insert_assets",
"bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_enrichment_level",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"convert_metadata_to_rows",
"delete_assets_by_ids",
"delete_orphaned_seed_asset",
"delete_reference_by_id",
"delete_references_by_ids",
"ensure_tags_exist",
"fetch_reference_and_asset",
"fetch_reference_asset_and_tags",
"get_asset_by_hash",
"get_existing_asset_ids",
"get_or_create_reference",
"get_reference_by_file_path",
"get_reference_by_id",
"get_reference_with_owner_check",
"get_reference_ids_by_ids",
"get_reference_tags",
"get_references_by_paths_and_asset_ids",
"get_references_for_prefixes",
"get_unenriched_references",
"get_unreferenced_unhashed_asset_ids",
"insert_reference",
"list_all_file_paths_by_asset_id",
"list_references_by_asset_id",
"list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",
"rebuild_metadata_projection",
"reference_exists",
"reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id",
"remove_tags_from_reference",
"restore_references_by_paths",
"set_reference_metadata",
"set_reference_preview",
"set_reference_system_metadata",
"soft_delete_reference_by_id",
"set_reference_tags",
"update_asset_hash_and_mime",
"update_is_missing_by_asset_id",
"update_reference_access_time",
"update_reference_name",
"update_reference_timestamps",
"update_reference_updated_at",
"upsert_asset",
"upsert_reference",
"validate_tags_exist",
]

View File

@ -0,0 +1,140 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and not asset.mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
if not rows:
return
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)
def get_existing_asset_ids(
session: Session,
asset_ids: list[str],
) -> set[str]:
"""Return the subset of asset_ids that exist in the database."""
if not asset_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
rows = session.execute(
select(Asset.id).where(Asset.id.in_(chunk))
).fetchall()
found.update(row[0] for row in rows)
return found
def update_asset_hash_and_mime(
session: Session,
asset_id: str,
asset_hash: str | None = None,
mime_type: str | None = None,
) -> bool:
"""Update asset hash and/or mime_type. Returns True if asset was found."""
asset = session.get(Asset, asset_id)
if not asset:
return False
if asset_hash is not None:
asset.hash = asset_hash
if mime_type is not None and not asset.mime_type:
asset.mime_type = mime_type
return True
def reassign_asset_references(
session: Session,
from_asset_id: str,
to_asset_id: str,
reference_id: str,
) -> None:
"""Reassign a reference from one asset to another.
Used when merging a stub asset into an existing asset with the same hash.
"""
ref = session.get(AssetReference, reference_id)
if ref and ref.asset_id == from_asset_id:
ref.asset_id = to_asset_id
session.flush()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,127 @@
"""Shared utilities for database query modules."""
import os
from decimal import Decimal
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800
def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count."""
return max(1, MAX_BIND_PARAMS // max(1, cols))
def iter_chunks(seq, n: int):
"""Yield successive n-sized chunks from seq."""
for i in range(0, len(seq), n):
yield seq[i : i + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads.
Owner-less rows are visible to everyone.
"""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetReference.owner_id == ""
return AssetReference.owner_id.in_(["", owner_id])
def build_prefix_like_conditions(
prefixes: list[str],
) -> list[sa.sql.ColumnElement]:
"""Build LIKE conditions for matching file paths under directory prefixes."""
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
return sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@ -0,0 +1,418 @@
from dataclasses import dataclass
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import (
Asset,
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id == reference_id)
.order_by(AssetReferenceTag.tag_name.asc())
)
).all()
]
def set_reference_tags(
session: Session,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsResult:
desired = normalize_tags(tags)
current = set(get_reference_tags(session, reference_id))
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
if to_remove:
session.execute(
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
def add_tags_to_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
reference_row: AssetReference | None = None,
) -> AddTagsResult:
if not reference_row:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return AddTagsResult(added=[], already_present=[], total_tags=total)
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = set(get_reference_tags(session, reference_id))
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_reference_tags(session, reference_id=reference_id))
return AddTagsResult(
added=sorted(((after - current) & want)),
already_present=sorted(want & current),
total_tags=sorted(after),
)
def remove_tags_from_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
) -> RemoveTagsResult:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
existing = set(get_reference_tags(session, reference_id))
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_reference_tags(session, reference_id=reference_id)
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
def add_missing_tag_for_asset_id(
session: Session,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetReference.id.label("asset_reference_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetReference.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetReferenceTag)
.from_select(
["asset_reference_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
)
def remove_missing_tag_for_asset_id(
session: Session,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id.in_(
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
),
AssetReferenceTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetReferenceTag.tag_name.label("tag_name"),
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
visible_tags_sq = (
select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
)
total_q = total_q.where(Tag.name.in_(visible_tags_sq))
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_reference_tags and asset_reference_meta.
Uses ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
"""
if tag_rows:
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
index_elements=[
AssetReferenceMeta.asset_reference_id,
AssetReferenceMeta.key,
AssetReferenceMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)

View File

@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@ -1,75 +0,0 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@ -1,226 +1,42 @@
import contextlib
import os import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone from datetime import datetime, timezone
from pathlib import Path from typing import Sequence
from typing import Literal, Any
import folder_paths
RootType = Literal["models", "input", "output"] def select_best_live_path(states: Sequence) -> str:
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
""" """
Gets a dictionary of query parameters from the request. Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic. 2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
""" """
query_dict = { alive = [
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key) s
for key in request.query.keys() for s in states
} if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
return query_dict ]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]: def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
if root == "models": """Escapes %, _ and the escape char in a LIKE prefix.
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]: Returns (escaped_prefix, escape_char).
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
""" """
s = s.replace(escape, escape + escape) # escape the escape char first s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime: def get_utc_now() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC.""" """Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None) return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]: def normalize_tags(tags: list[str] | None) -> list[str]:
""" """
@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Stripping whitespace and converting to lowercase. - Stripping whitespace and converting to lowercase.
- Removing duplicates. - Removing duplicates.
""" """
return [t.strip().lower() for t in (tags or []) if (t or "").strip()] return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def is_scalar(v): def validate_blake3_hash(s: str) -> str:
if v is None: """Validate and normalize a blake3 hash string.
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def project_kv(key: str, value): Returns canonical 'blake3:<hex>' or raises ValueError.
""" """
Turn a metadata key/value into typed projection rows. s = s.strip().lower()
Returns list[dict] with keys: if not s or ":" not in s:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None) raise ValueError("hash must be 'blake3:<hex>'")
""" algo, digest = s.split(":", 1)
rows: list[dict] = [] if (
algo != "blake3"
def _null_row(ordinal: int) -> dict: or len(digest) != 64
return { or any(c for c in digest if c not in "0123456789abcdef")
"key": key, "ordinal": ordinal, ):
"val_str": None, "val_num": None, "val_bool": None, "val_json": None raise ValueError("hash must be 'blake3:<hex>'")
} return f"{algo}:{digest}"
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows

View File

@ -1,516 +0,0 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@ -1,263 +1,567 @@
import contextlib
import time
import logging import logging
import os import os
import sqlalchemy from pathlib import Path
from typing import Callable, Literal, TypedDict
import folder_paths import folder_paths
from app.database.db import create_session, dependencies_available from app.assets.database.queries import (
from app.assets.helpers import ( add_missing_tag_for_asset_id,
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path, bulk_update_enrichment_level,
list_tree,prefixes_for_root, escape_like_prefix, bulk_update_is_missing,
RootType bulk_update_needs_verify,
delete_orphaned_seed_asset,
delete_references_by_ids,
ensure_tags_exist,
get_asset_by_hash,
get_references_for_prefixes,
get_unenriched_references,
mark_references_missing_outside_prefixes,
reassign_asset_references,
remove_missing_tag_for_asset_id,
set_reference_system_metadata,
update_asset_hash_and_mime,
) )
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id from app.assets.services.bulk_ingest import (
from app.assets.database.bulk_ops import seed_from_paths_batch SeedAssetSpec,
from app.assets.database.models import Asset, AssetCacheState, AssetInfo batch_insert_seed_assets,
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
) )
for t in tags: from app.assets.services.file_utils import (
tag_pool.add(t) get_mtime_ns,
# if no file specs, nothing to do is_visible,
if not specs: list_files_recursively,
return verify_file_unchanged,
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
) )
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
from app.database.db import create_session
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int: class _RefInfo(TypedDict):
"""Prune cache states outside configured prefixes, then delete orphaned seed assets.""" ref_id: str
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)] file_path: str
if not all_prefixes: exists: bool
return 0 stat_unchanged: bool
needs_verify: bool
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass( class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
refs: list[_RefInfo]
RootType = Literal["models", "input", "output"]
def get_prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [p for root in all_roots for p in get_prefixes_for_root(root)]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
if not all(is_visible(part) for part in Path(rel_path).parts):
continue
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
abs_p = Path(abs_path)
for b in bases:
if abs_p.is_relative_to(os.path.abspath(b)):
allowed = True
break
if allowed:
out.append(abs_path)
return out
def sync_references_with_filesystem(
session,
root: RootType, root: RootType,
*,
collect_existing_paths: bool = False, collect_existing_paths: bool = False,
update_missing_tags: bool = False, update_missing_tags: bool = False,
) -> set[str] | None: ) -> set[str] | None:
"""Fast DB+FS pass for a root: """Reconcile asset references with filesystem for a root.
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states - Toggle needs_verify per reference using mtime/size stat check
- For seed assets with all states missing: delete Asset and its AssetInfos - For hashed assets with at least one stat-unchanged ref: delete stale missing refs
- Optionally add/remove 'missing' tags based on fast-ok in this root - For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on stat check in this root
- Optionally return surviving absolute paths - Optionally return surviving absolute paths
Args:
session: Database session
root: Root type to scan
collect_existing_paths: If True, return set of surviving file paths
update_missing_tags: If True, update 'missing' tags based on file status
Returns:
Set of surviving absolute paths if collect_existing_paths=True, else None
""" """
prefixes = prefixes_for_root(root) prefixes = get_prefixes_for_root(root)
if not prefixes: if not prefixes:
return set() if collect_existing_paths else None return set() if collect_existing_paths else None
conds = [] rows = get_references_for_prefixes(
for p in prefixes: session, prefixes, include_missing=update_missing_tags
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
) )
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
)
).all()
by_asset: dict[str, dict] = {} by_asset: dict[str, _AssetAccumulator] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows: for row in rows:
acc = by_asset.get(aid) acc = by_asset.get(row.asset_id)
if acc is None: if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []} acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[aid] = acc by_asset[row.asset_id] = acc
fast_ok = False stat_unchanged = False
try: try:
exists = True exists = True
fast_ok = fast_asset_file_check( stat_unchanged = verify_file_unchanged(
mtime_db=mtime_db, mtime_db=row.mtime_ns,
size_db=acc["size_db"], size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True), stat_result=os.stat(row.file_path, follow_symlinks=True),
) )
except FileNotFoundError: except FileNotFoundError:
exists = False exists = False
except OSError: except PermissionError:
exists = True
logging.debug("Permission denied accessing %s", row.file_path)
except OSError as e:
exists = False exists = False
logging.debug("OSError checking %s: %s", row.file_path, e)
acc["states"].append({ acc["refs"].append(
"sid": sid, {
"fp": fp, "ref_id": row.reference_id,
"file_path": row.file_path,
"exists": exists, "exists": exists,
"fast_ok": fast_ok, "stat_unchanged": stat_unchanged,
"needs_verify": bool(needs_verify), "needs_verify": row.needs_verify,
}) }
)
to_set_verify: list[int] = [] to_set_verify: list[str] = []
to_clear_verify: list[int] = [] to_clear_verify: list[str] = []
stale_state_ids: list[int] = [] stale_ref_ids: list[str] = []
to_mark_missing: list[str] = []
to_clear_missing: list[str] = []
survivors: set[str] = set() survivors: set[str] = set()
for aid, acc in by_asset.items(): for aid, acc in by_asset.items():
a_hash = acc["hash"] a_hash = acc["hash"]
states = acc["states"] refs = acc["refs"]
any_fast_ok = any(s["fast_ok"] for s in states) any_unchanged = any(r["stat_unchanged"] for r in refs)
all_missing = all(not s["exists"] for s in states) all_missing = all(not r["exists"] for r in refs)
for s in states: for r in refs:
if not s["exists"]: if not r["exists"]:
to_mark_missing.append(r["ref_id"])
continue continue
if s["fast_ok"] and s["needs_verify"]: if r["stat_unchanged"]:
to_clear_verify.append(s["sid"]) to_clear_missing.append(r["ref_id"])
if not s["fast_ok"] and not s["needs_verify"]: if r["needs_verify"]:
to_set_verify.append(s["sid"]) to_clear_verify.append(r["ref_id"])
if not r["stat_unchanged"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"])
if a_hash is None: if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists if refs and all_missing:
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid)) delete_orphaned_seed_asset(session, aid)
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
else: else:
for s in states: for r in refs:
if s["exists"]: if r["exists"]:
survivors.add(os.path.abspath(s["fp"])) survivors.add(os.path.abspath(r["file_path"]))
continue continue
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records if any_unchanged:
for s in states: for r in refs:
if not s["exists"]: if not r["exists"]:
stale_state_ids.append(s["sid"]) stale_ref_ids.append(r["ref_id"])
if update_missing_tags: if update_missing_tags:
with contextlib.suppress(Exception): try:
remove_missing_tag_for_asset_id(sess, asset_id=aid) remove_missing_tag_for_asset_id(session, asset_id=aid)
elif update_missing_tags: except Exception as e:
with contextlib.suppress(Exception): logging.warning(
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic") "Failed to remove missing tag for asset %s: %s", aid, e
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
) )
if to_clear_verify: elif update_missing_tags:
sess.execute( try:
sqlalchemy.update(AssetCacheState) add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
.where(AssetCacheState.id.in_(to_clear_verify)) except Exception as e:
.values(needs_verify=False) logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["file_path"]))
delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids)
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
bulk_update_is_missing(session, to_mark_missing, value=True)
bulk_update_is_missing(session, to_clear_missing, value=False)
bulk_update_needs_verify(session, to_set_verify, value=True)
bulk_update_needs_verify(session, to_clear_verify, value=False)
return survivors if collect_existing_paths else None
def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's references with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_references_with_filesystem(
sess,
root,
collect_existing_paths=True,
update_missing_tags=True,
) )
sess.commit() sess.commit()
return survivors if collect_existing_paths else None return survivors or set()
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", root, e)
return set()
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
"""Mark references as missing when outside the given prefixes.
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
"""
try:
with create_session() as sess:
count = mark_references_missing_outside_prefixes(sess, prefixes)
sess.commit()
return count
except Exception as e:
logging.exception("marking missing assets failed: %s", e)
return 0
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths
def build_asset_specs(
paths: list[str],
existing_paths: set[str],
enable_metadata_extraction: bool = True,
compute_hashes: bool = False,
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
compute_hashes: If True, compute blake3 hashes (slow for large files)
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=True)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
metadata = None
if enable_metadata_extraction:
metadata = extract_file_metadata(
abs_p,
stat_result=stat_p,
relative_filename=rel_fname,
)
# Compute hash if requested
asset_hash: str | None = None
if compute_hashes:
try:
digest, _ = compute_blake3_hash(abs_p)
asset_hash = "blake3:" + digest
except Exception as e:
logging.warning("Failed to hash %s: %s", abs_p, e)
mime_type = metadata.content_type if metadata else None
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": rel_fname,
"metadata": metadata,
"hash": asset_hash,
"mime_type": mime_type,
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created refs."""
if not specs:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_refs
# Enrichment level constants
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
def get_unenriched_assets_for_roots(
roots: tuple[RootType, ...],
max_level: int = ENRICHMENT_STUB,
limit: int = 1000,
) -> list:
"""Get assets that need enrichment for the given roots.
Args:
roots: Tuple of root types to scan
max_level: Maximum enrichment level to include
limit: Maximum number of rows to return
Returns:
List of UnenrichedReferenceRow
"""
prefixes: list[str] = []
for root in roots:
prefixes.extend(get_prefixes_for_root(root))
if not prefixes:
return []
with create_session() as sess:
return get_unenriched_references(
sess, prefixes, max_level=max_level, limit=limit
)
def enrich_asset(
session,
file_path: str,
reference_id: str,
asset_id: str,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> int:
"""Enrich a single asset with metadata and/or hash.
Args:
session: Database session (caller manages lifecycle)
file_path: Absolute path to the file
reference_id: ID of the reference to update
asset_id: ID of the asset to update (for mime_type and hash)
extract_metadata: If True, extract safetensors header and mime type
compute_hash: If True, compute blake3 hash
interrupt_check: Optional non-blocking callable that returns True if
the operation should be interrupted (e.g. paused or cancelled)
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
New enrichment level achieved
"""
new_level = ENRICHMENT_STUB
try:
stat_p = os.stat(file_path, follow_symlinks=True)
except OSError:
return new_level
rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None
metadata = None
if extract_metadata:
metadata = extract_file_metadata(
file_path,
stat_result=stat_p,
relative_filename=rel_fname,
)
if metadata:
mime_type = metadata.content_type
new_level = ENRICHMENT_METADATA
full_hash: str | None = None
if compute_hash:
try:
mtime_before = get_mtime_ns(stat_p)
size_before = stat_p.st_size
# Restore checkpoint if available and file unchanged
checkpoint = None
if hash_checkpoints is not None:
checkpoint = hash_checkpoints.get(file_path)
if checkpoint is not None:
cur_stat = os.stat(file_path, follow_symlinks=True)
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
or checkpoint.file_size != cur_stat.st_size):
checkpoint = None
hash_checkpoints.pop(file_path, None)
else:
mtime_before = get_mtime_ns(cur_stat)
digest, new_checkpoint = compute_blake3_hash(
file_path,
interrupt_check=interrupt_check,
checkpoint=checkpoint,
)
if digest is None:
# Interrupted — save checkpoint for later resumption
if hash_checkpoints is not None and new_checkpoint is not None:
new_checkpoint.mtime_ns = mtime_before
new_checkpoint.file_size = size_before
hash_checkpoints[file_path] = new_checkpoint
return new_level
# Completed — clear any saved checkpoint
if hash_checkpoints is not None:
hash_checkpoints.pop(file_path, None)
stat_after = os.stat(file_path, follow_symlinks=True)
mtime_after = get_mtime_ns(stat_after)
if mtime_before != mtime_after:
logging.warning("File modified during hashing, discarding hash: %s", file_path)
else:
full_hash = f"blake3:{digest}"
metadata_ok = not extract_metadata or metadata is not None
if metadata_ok:
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata:
system_metadata = metadata.to_user_metadata()
set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash:
existing = get_asset_by_hash(session, full_hash)
if existing and existing.id != asset_id:
reassign_asset_references(session, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(session, asset_id)
if mime_type:
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
else:
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
elif mime_type:
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
bulk_update_enrichment_level(session, [reference_id], new_level)
session.commit()
return new_level
def enrich_assets_batch(
rows: list,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> tuple[int, list[str]]:
"""Enrich a batch of assets.
Uses a single DB session for the entire batch, committing after each
individual asset to avoid long-held transactions while eliminating
per-asset session creation overhead.
Args:
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
extract_metadata: If True, extract metadata for each asset
compute_hash: If True, compute hash for each asset
interrupt_check: Optional non-blocking callable that returns True if
the operation should be interrupted (e.g. paused or cancelled)
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
Tuple of (enriched_count, failed_reference_ids)
"""
enriched = 0
failed_ids: list[str] = []
with create_session() as sess:
for row in rows:
if interrupt_check is not None and interrupt_check():
break
try:
new_level = enrich_asset(
sess,
file_path=row.file_path,
reference_id=row.reference_id,
asset_id=row.asset_id,
extract_metadata=extract_metadata,
compute_hash=compute_hash,
interrupt_check=interrupt_check,
hash_checkpoints=hash_checkpoints,
)
if new_level > row.enrichment_level:
enriched += 1
else:
failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
sess.rollback()
failed_ids.append(row.reference_id)
return enriched, failed_ids

794
app/assets/seeder.py Normal file
View File

@ -0,0 +1,794 @@
"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable
from app.assets.scanner import (
ENRICHMENT_METADATA,
ENRICHMENT_STUB,
RootType,
build_asset_specs,
collect_paths_for_roots,
enrich_assets_batch,
get_all_known_prefixes,
get_prefixes_for_root,
get_unenriched_assets_for_roots,
insert_asset_specs,
mark_missing_outside_prefixes_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
class ScanInProgressError(Exception):
"""Raised when an operation cannot proceed because a scan is running."""
class State(Enum):
"""Seeder state machine states."""
IDLE = "IDLE"
RUNNING = "RUNNING"
PAUSED = "PAUSED"
CANCELLING = "CANCELLING"
class ScanPhase(Enum):
"""Scan phase options."""
FAST = "fast" # Phase 1: filesystem only (stubs)
ENRICH = "enrich" # Phase 2: metadata + hash
FULL = "full" # Both phases sequentially
@dataclass
class Progress:
"""Progress information for a scan operation."""
scanned: int = 0
total: int = 0
created: int = 0
skipped: int = 0
@dataclass
class ScanStatus:
"""Current status of the asset seeder."""
state: State
progress: Progress | None
errors: list[str] = field(default_factory=list)
ProgressCallback = Callable[[Progress], None]
class _AssetSeeder:
"""Background asset scanning manager.
Spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete.
Use the module-level ``asset_seeder`` instance.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._state = State.IDLE
self._progress: Progress | None = None
self._last_progress: Progress | None = None
self._errors: list[str] = []
self._thread: threading.Thread | None = None
self._cancel_event = threading.Event()
self._run_gate = threading.Event()
self._run_gate.set() # Start unpaused (set = running, clear = paused)
self._roots: tuple[RootType, ...] = ()
self._phase: ScanPhase = ScanPhase.FULL
self._compute_hashes: bool = False
self._prune_first: bool = False
self._progress_callback: ProgressCallback | None = None
self._disabled: bool = False
def disable(self) -> None:
"""Disable the asset seeder, preventing any scans from starting."""
self._disabled = True
logging.info("Asset seeder disabled")
def is_disabled(self) -> bool:
"""Check if the asset seeder is disabled."""
return self._disabled
def start(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
phase: ScanPhase = ScanPhase.FULL,
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
compute_hashes: bool = False,
) -> bool:
"""Start a background scan for the given roots.
Args:
roots: Tuple of root types to scan (models, input, output)
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
progress_callback: Optional callback called with progress updates
prune_first: If True, prune orphaned assets before scanning
compute_hashes: If True, compute blake3 hashes (slow)
Returns:
True if scan was started, False if already running
"""
if self._disabled:
logging.debug("Asset seeder is disabled, skipping start")
return False
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
with self._lock:
if self._state != State.IDLE:
logging.info("Asset seeder already running, skipping start")
return False
self._state = State.RUNNING
self._progress = Progress()
self._errors = []
self._roots = roots
self._phase = phase
self._prune_first = prune_first
self._compute_hashes = compute_hashes
self._progress_callback = progress_callback
self._cancel_event.clear()
self._run_gate.set() # Ensure unpaused when starting
self._thread = threading.Thread(
target=self._run_scan,
name="_AssetSeeder",
daemon=True,
)
self._thread.start()
return True
def start_fast(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
) -> bool:
"""Start a fast scan (phase 1 only) - creates stub records.
Args:
roots: Tuple of root types to scan
progress_callback: Optional callback for progress updates
prune_first: If True, prune orphaned assets before scanning
Returns:
True if scan was started, False if already running
"""
return self.start(
roots=roots,
phase=ScanPhase.FAST,
progress_callback=progress_callback,
prune_first=prune_first,
compute_hashes=False,
)
def start_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
Args:
roots: Tuple of root types to scan
progress_callback: Optional callback for progress updates
compute_hashes: If True, compute blake3 hashes
Returns:
True if scan was started, False if already running
"""
return self.start(
roots=roots,
phase=ScanPhase.ENRICH,
progress_callback=progress_callback,
prune_first=False,
compute_hashes=compute_hashes,
)
def cancel(self) -> bool:
"""Request cancellation of the current scan.
Returns:
True if cancellation was requested, False if not running or paused
"""
with self._lock:
if self._state not in (State.RUNNING, State.PAUSED):
return False
logging.info("Asset seeder cancelling (was %s)", self._state.value)
self._state = State.CANCELLING
self._cancel_event.set()
self._run_gate.set() # Unblock if paused so thread can exit
return True
def stop(self) -> bool:
"""Stop the current scan (alias for cancel).
Returns:
True if stop was requested, False if not running
"""
return self.cancel()
def pause(self) -> bool:
"""Pause the current scan.
The scan will complete its current batch before pausing.
Returns:
True if pause was requested, False if not running
"""
with self._lock:
if self._state != State.RUNNING:
return False
logging.info("Asset seeder pausing")
self._state = State.PAUSED
self._run_gate.clear()
return True
def resume(self) -> bool:
"""Resume a paused scan.
This is a noop if the scan is not in the PAUSED state
Returns:
True if resumed, False if not paused
"""
with self._lock:
if self._state != State.PAUSED:
return False
logging.info("Asset seeder resuming")
self._state = State.RUNNING
self._run_gate.set()
self._emit_event("assets.seed.resumed", {})
return True
def restart(
self,
roots: tuple[RootType, ...] | None = None,
phase: ScanPhase | None = None,
progress_callback: ProgressCallback | None = None,
prune_first: bool | None = None,
compute_hashes: bool | None = None,
timeout: float = 5.0,
) -> bool:
"""Cancel any running scan and start a new one.
Args:
roots: Roots to scan (defaults to previous roots)
phase: Scan phase (defaults to previous phase)
progress_callback: Progress callback (defaults to previous)
prune_first: Prune before scan (defaults to previous)
compute_hashes: Compute hashes (defaults to previous)
timeout: Max seconds to wait for current scan to stop
Returns:
True if new scan was started, False if failed to stop previous
"""
logging.info("Asset seeder restart requested")
with self._lock:
prev_roots = self._roots
prev_phase = self._phase
prev_callback = self._progress_callback
prev_prune = self._prune_first
prev_hashes = self._compute_hashes
self.cancel()
if not self.wait(timeout=timeout):
return False
cb = progress_callback if progress_callback is not None else prev_callback
return self.start(
roots=roots if roots is not None else prev_roots,
phase=phase if phase is not None else prev_phase,
progress_callback=cb,
prune_first=prune_first if prune_first is not None else prev_prune,
compute_hashes=(
compute_hashes if compute_hashes is not None else prev_hashes
),
)
def wait(self, timeout: float | None = None) -> bool:
"""Wait for the current scan to complete.
Args:
timeout: Maximum seconds to wait, or None for no timeout
Returns:
True if scan completed, False if timeout expired or no scan running
"""
with self._lock:
thread = self._thread
if thread is None:
return True
thread.join(timeout=timeout)
return not thread.is_alive()
def get_status(self) -> ScanStatus:
"""Get the current status and progress of the seeder."""
with self._lock:
src = self._progress or self._last_progress
return ScanStatus(
state=self._state,
progress=Progress(
scanned=src.scanned,
total=src.total,
created=src.created,
skipped=src.skipped,
)
if src
else None,
errors=list(self._errors),
)
def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown: cancel any running scan and wait for thread.
Args:
timeout: Maximum seconds to wait for thread to exit
"""
self.cancel()
self.wait(timeout=timeout)
with self._lock:
self._thread = None
def mark_missing_outside_prefixes(self) -> int:
"""Mark references as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but references are flagged as missing.
They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans
from accidentally marking assets belonging to other roots.
Should be called explicitly when cleanup is desired, typically after
a full scan of all roots or during maintenance.
Returns:
Number of references marked as missing
Raises:
ScanInProgressError: If a scan is currently running
"""
with self._lock:
if self._state != State.IDLE:
raise ScanInProgressError(
"Cannot mark missing assets while scan is running"
)
self._state = State.RUNNING
try:
if not dependencies_available():
logging.warning(
"Database dependencies not available, skipping mark missing"
)
return 0
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d references as missing", marked)
return marked
finally:
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _is_cancelled(self) -> bool:
"""Check if cancellation has been requested."""
return self._cancel_event.is_set()
def _is_paused_or_cancelled(self) -> bool:
"""Non-blocking check: True if paused or cancelled.
Use as interrupt_check for I/O-bound work (e.g. hashing) so that
file handles are released immediately on pause rather than held
open while blocked. The caller is responsible for blocking on
_check_pause_and_cancel() afterward.
"""
return not self._run_gate.is_set() or self._cancel_event.is_set()
def _check_pause_and_cancel(self) -> bool:
"""Block while paused, then check if cancelled.
Call this at checkpoint locations in scan loops. It will:
1. Block indefinitely while paused (until resume or cancel)
2. Return True if cancelled, False to continue
Returns:
True if scan should stop, False to continue
"""
if not self._run_gate.is_set():
self._emit_event("assets.seed.paused", {})
self._run_gate.wait() # Blocks if paused
return self._is_cancelled()
def _emit_event(self, event_type: str, data: dict) -> None:
"""Emit a WebSocket event if server is available."""
try:
from server import PromptServer
if hasattr(PromptServer, "instance") and PromptServer.instance:
PromptServer.instance.send_sync(event_type, data)
except Exception:
pass
def _update_progress(
self,
scanned: int | None = None,
total: int | None = None,
created: int | None = None,
skipped: int | None = None,
) -> None:
"""Update progress counters (thread-safe)."""
callback: ProgressCallback | None = None
progress: Progress | None = None
with self._lock:
if self._progress is None:
return
if scanned is not None:
self._progress.scanned = scanned
if total is not None:
self._progress.total = total
if created is not None:
self._progress.created = created
if skipped is not None:
self._progress.skipped = skipped
if self._progress_callback:
callback = self._progress_callback
progress = Progress(
scanned=self._progress.scanned,
total=self._progress.total,
created=self._progress.created,
skipped=self._progress.skipped,
)
if callback and progress:
try:
callback(progress)
except Exception:
pass
_MAX_ERRORS = 200
def _add_error(self, message: str) -> None:
"""Add an error message (thread-safe), capped at _MAX_ERRORS."""
with self._lock:
if len(self._errors) < self._MAX_ERRORS:
self._errors.append(message)
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
"""Log the directories that will be scanned."""
import folder_paths
for root in roots:
if root == "models":
logging.info(
"Asset scan [models] directory: %s",
os.path.abspath(folder_paths.models_dir),
)
else:
prefixes = get_prefixes_for_root(root)
if prefixes:
logging.info("Asset scan [%s] directories: %s", root, prefixes)
def _run_scan(self) -> None:
"""Main scan loop running in background thread."""
t_start = time.perf_counter()
roots = self._roots
phase = self._phase
cancelled = False
total_created = 0
total_enriched = 0
skipped_existing = 0
total_paths = 0
try:
if not dependencies_available():
self._add_error("Database dependencies not available")
self._emit_event(
"assets.seed.error",
{"message": "Database dependencies not available"},
)
return
if self._prune_first:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d refs as missing before scan", marked)
if self._check_pause_and_cancel():
logging.info("Asset scan cancelled after pruning phase")
cancelled = True
return
self._log_scan_config(roots)
# Phase 1: Fast scan (stub records)
if phase in (ScanPhase.FAST, ScanPhase.FULL):
created, skipped, paths = self._run_fast_phase(roots)
total_created, skipped_existing, total_paths = created, skipped, paths
if self._check_pause_and_cancel():
cancelled = True
return
self._emit_event(
"assets.seed.fast_complete",
{
"roots": list(roots),
"created": total_created,
"skipped": skipped_existing,
"total": total_paths,
},
)
# Phase 2: Enrichment scan (metadata + hashes)
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
if self._check_pause_and_cancel():
cancelled = True
return
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
if enrich_cancelled:
cancelled = True
return
self._emit_event(
"assets.seed.enrich_complete",
{
"roots": list(roots),
"enriched": total_enriched,
},
)
elapsed = time.perf_counter() - t_start
logging.info(
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
roots,
phase.value,
elapsed,
total_created,
total_enriched,
skipped_existing,
)
self._emit_event(
"assets.seed.completed",
{
"phase": phase.value,
"total": total_paths,
"created": total_created,
"enriched": total_enriched,
"skipped": skipped_existing,
"elapsed": round(elapsed, 3),
},
)
except Exception as e:
self._add_error(f"Scan failed: {e}")
logging.exception("Asset scan failed")
self._emit_event("assets.seed.error", {"message": str(e)})
finally:
if cancelled:
self._emit_event(
"assets.seed.cancelled",
{
"scanned": self._progress.scanned if self._progress else 0,
"total": total_paths,
"created": total_created,
},
)
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
"""Run phase 1: fast scan to create stub records.
Returns:
Tuple of (total_created, skipped_existing, total_paths)
"""
t_fast_start = time.perf_counter()
total_created = 0
skipped_existing = 0
existing_paths: set[str] = set()
t_sync = time.perf_counter()
for r in roots:
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
existing_paths.update(sync_root_safely(r))
logging.debug(
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
time.perf_counter() - t_sync,
len(existing_paths),
)
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
t_collect = time.perf_counter()
paths = collect_paths_for_roots(roots)
logging.debug(
"Fast scan: collect_paths took %.3fs (%d paths found)",
time.perf_counter() - t_collect,
len(paths),
)
total_paths = len(paths)
self._update_progress(total=total_paths)
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "total": total_paths, "phase": "fast"},
)
# Use stub specs (no metadata extraction, no hashing)
t_specs = time.perf_counter()
specs, tag_pool, skipped_existing = build_asset_specs(
paths,
existing_paths,
enable_metadata_extraction=False,
compute_hashes=False,
)
logging.debug(
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
time.perf_counter() - t_specs,
len(specs),
skipped_existing,
)
self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel():
return total_created, skipped_existing, total_paths
batch_size = 500
last_progress_time = time.perf_counter()
progress_interval = 1.0
for i in range(0, len(specs), batch_size):
if self._check_pause_and_cancel():
logging.info(
"Fast scan cancelled after %d/%d files (created=%d)",
i,
len(specs),
total_created,
)
return total_created, skipped_existing, total_paths
batch = specs[i : i + batch_size]
batch_tags = {t for spec in batch for t in spec["tags"]}
try:
created = insert_asset_specs(batch, batch_tags)
total_created += created
except Exception as e:
self._add_error(f"Batch insert failed at offset {i}: {e}")
logging.exception("Batch insert failed at offset %d", i)
scanned = i + len(batch)
now = time.perf_counter()
self._update_progress(scanned=scanned, created=total_created)
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"phase": "fast",
"scanned": scanned,
"total": len(specs),
"created": total_created,
},
)
last_progress_time = now
self._update_progress(scanned=len(specs), created=total_created)
logging.info(
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
time.perf_counter() - t_fast_start,
total_created,
skipped_existing,
total_paths,
)
return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
"""Run phase 2: enrich existing records with metadata and hashes.
Returns:
Tuple of (cancelled, total_enriched)
"""
total_enriched = 0
batch_size = 100
last_progress_time = time.perf_counter()
progress_interval = 1.0
# Get the target enrichment level based on compute_hashes
if not self._compute_hashes:
target_max_level = ENRICHMENT_STUB
else:
target_max_level = ENRICHMENT_METADATA
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "phase": "enrich"},
)
skip_ids: set[str] = set()
consecutive_empty = 0
max_consecutive_empty = 3
# Hash checkpoints survive across batches so interrupted hashes
# can be resumed without re-reading the entire file.
hash_checkpoints: dict[str, object] = {}
while True:
if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched)
return True, total_enriched
# Fetch next batch of unenriched assets
unenriched = get_unenriched_assets_for_roots(
roots,
max_level=target_max_level,
limit=batch_size,
)
# Filter out previously failed references
if skip_ids:
unenriched = [r for r in unenriched if r.reference_id not in skip_ids]
if not unenriched:
break
enriched, failed_ids = enrich_assets_batch(
unenriched,
extract_metadata=True,
compute_hash=self._compute_hashes,
interrupt_check=self._is_paused_or_cancelled,
hash_checkpoints=hash_checkpoints,
)
total_enriched += enriched
skip_ids.update(failed_ids)
if enriched == 0:
consecutive_empty += 1
if consecutive_empty >= max_consecutive_empty:
logging.warning(
"Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
consecutive_empty,
len(skip_ids),
)
break
else:
consecutive_empty = 0
now = time.perf_counter()
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"phase": "enrich",
"enriched": total_enriched,
},
)
last_progress_time = now
return False, total_enriched
asset_seeder = _AssetSeeder()

View File

@ -0,0 +1,87 @@
from app.assets.services.asset_management import (
asset_exists,
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
)
from app.assets.services.bulk_ingest import (
BulkInsertResult,
batch_insert_seed_assets,
cleanup_unreferenced_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
get_size_and_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
upload_from_temp_path,
)
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
)
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
ReferenceData,
RegisterAssetResult,
TagUsage,
UploadResult,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
remove_tags,
)
__all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetSummaryData",
"ReferenceData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",
"HashMismatchError",
"IngestResult",
"ListAssetsResult",
"RegisterAssetResult",
"RemoveTagsResult",
"TagUsage",
"UploadResult",
"UserMetadata",
"apply_tags",
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"get_mtime_ns",
"get_size_and_mtime_ns",
"list_assets_page",
"list_files_recursively",
"list_tags",
"cleanup_unreferenced_assets",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",
"update_asset_metadata",
"upload_from_temp_path",
"verify_file_unchanged",
]

View File

@ -0,0 +1,367 @@
import contextlib
import mimetypes
import os
from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
reference_exists_for_asset_id,
delete_reference_by_id,
fetch_reference_and_asset,
soft_delete_reference_by_id,
fetch_reference_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_reference_by_id,
get_reference_with_owner_check,
list_references_page,
list_all_file_paths_by_asset_id,
list_references_by_asset_id,
set_reference_metadata,
set_reference_preview,
set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time,
update_reference_name,
update_reference_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_relative_filename
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_reference_data,
)
from app.database.db import create_session
def get_asset_detail(
reference_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_reference_asset_and_tags(
session,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
return None
ref, asset, tags = result
return AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
reference_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
touched = False
if name is not None and name != ref.name:
update_reference_name(session, reference_id=reference_id, name=name)
touched = True
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = ref.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_reference_metadata(
session, reference_id=reference_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_reference_tags(
session,
reference_id=reference_id,
tags=tags,
origin=tag_origin,
)
touched = True
if mime_type is not None:
updated = update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
if updated:
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_id,
)
touched = True
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)
result = fetch_reference_asset_and_tags(
session,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
ref, asset, tag_list = result
detail = AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_list,
)
session.commit()
return detail
def delete_asset_reference(
reference_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
if not delete_content_if_orphan:
# Soft delete: mark the reference as deleted but keep everything
deleted = soft_delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
session.commit()
return deleted
ref_row = get_reference_by_id(session, reference_id=reference_id)
asset_id = ref_row.asset_id if ref_row else None
file_path = ref_row.file_path if ref_row else None
deleted = delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
if not deleted:
session.commit()
return False
if not asset_id:
session.commit()
return True
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - gather ALL file paths (including
# soft-deleted / missing refs) so their on-disk files get cleaned up.
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
# Also include the just-deleted file path
if file_path:
file_paths.append(file_path)
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
# Delete files after commit
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def set_asset_preview(
reference_id: str,
preview_reference_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
get_reference_with_owner_check(session, reference_id, owner_id)
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_reference_id,
)
result = fetch_reference_asset_and_tags(
session, reference_id=reference_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
ref, asset, tags = result
detail = AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
session.commit()
return detail
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
return extract_asset_data(asset)
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
AssetSummaryData(
ref=extract_reference_data(ref),
asset=extract_asset_data(ref.asset),
tags=tag_map.get(ref.id, []),
)
)
return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(
asset_hash: str,
owner_id: str = "",
) -> DownloadResolutionResult | None:
"""Resolve a blake3 hash to an on-disk file path.
Only references visible to *owner_id* are considered (owner-less
references are always visible).
Returns a DownloadResolutionResult with abs_path, content_type, and
download_name, or None if no asset or live path is found.
"""
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash)
if not asset:
return None
refs = list_references_by_asset_id(session, asset_id=asset.id)
visible = [
r for r in refs
if r.owner_id == "" or r.owner_id == owner_id
]
abs_path = select_best_live_path(visible)
if not abs_path:
return None
display_name = os.path.basename(abs_path)
for ref in visible:
if ref.file_path == abs_path and ref.name:
display_name = ref.name
break
ctype = (
asset.mime_type
or mimetypes.guess_type(display_name)[0]
or "application/octet-stream"
)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=display_name,
)
def resolve_asset_for_download(
reference_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetReference {reference_id} not found")
ref, asset = pair
# For references with file_path, use that directly
if ref.file_path and os.path.isfile(ref.file_path):
abs_path = ref.file_path
else:
# For API-created refs without file_path, find a path from other refs
refs = list_references_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(refs)
if not abs_path:
raise FileNotFoundError(
f"No live path for AssetReference {reference_id} "
f"(asset id={asset.id}, name={ref.name})"
)
# Capture ORM attributes before commit (commit expires loaded objects)
ref_name = ref.name
asset_mime = asset.mime_type
update_reference_access_time(session, reference_id=reference_id)
session.commit()
ctype = (
asset_mime
or mimetypes.guess_type(ref_name or abs_path)[0]
or "application/octet-stream"
)
download_name = ref_name or os.path.basename(abs_path)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=download_name,
)

View File

@ -0,0 +1,280 @@
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
from app.assets.database.queries import (
bulk_insert_assets,
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
restore_references_by_paths,
)
from app.assets.helpers import get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
abs_path: str
size_bytes: int
mtime_ns: int
info_name: str
tags: list[str]
fname: str
metadata: ExtractedMetadata | None
hash: str | None
mime_type: str | None
class AssetRow(TypedDict):
"""Row data for inserting an Asset."""
id: str
hash: str | None
size_bytes: int
mime_type: str | None
created_at: datetime
class ReferenceRow(TypedDict):
"""Row data for inserting an AssetReference."""
id: str
asset_id: str
file_path: str
mtime_ns: int
owner_id: str
name: str
preview_id: str | None
user_metadata: dict[str, Any] | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
class TagRow(TypedDict):
"""Row data for inserting a Tag."""
asset_reference_id: str
tag_name: str
origin: str
added_at: datetime
class MetadataRow(TypedDict):
"""Row data for inserting asset metadata."""
asset_reference_id: str
key: str
ordinal: int
val_str: str | None
val_num: float | None
val_bool: bool | None
val_json: dict[str, Any] | None
@dataclass
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_refs: int
won_paths: int
lost_paths: int
def batch_insert_seed_assets(
session: Session,
specs: list[SeedAssetSpec],
owner_id: str = "",
) -> BulkInsertResult:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim references with ON CONFLICT DO NOTHING on file_path
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert tags and metadata for successfully inserted references
Returns:
BulkInsertResult with inserted_refs, won_paths, lost_paths
"""
if not specs:
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
current_time = get_utc_now()
asset_rows: list[AssetRow] = []
reference_rows: list[ReferenceRow] = []
path_to_asset_id: dict[str, str] = {}
asset_id_to_ref_data: dict[str, dict] = {}
absolute_path_list: list[str] = []
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
asset_id = str(uuid.uuid4())
reference_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path)
path_to_asset_id[absolute_path] = asset_id
mime_type = spec.get("mime_type")
asset_rows.append(
{
"id": asset_id,
"hash": spec.get("hash"),
"size_bytes": spec["size_bytes"],
"mime_type": mime_type,
"created_at": current_time,
}
)
# Build user_metadata from extracted metadata or fallback to filename
extracted_metadata = spec.get("metadata")
if extracted_metadata:
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
elif spec["fname"]:
user_metadata = {"filename": spec["fname"]}
else:
user_metadata = None
reference_rows.append(
{
"id": reference_id,
"asset_id": asset_id,
"file_path": absolute_path,
"mtime_ns": spec["mtime_ns"],
"owner_id": owner_id,
"name": spec["info_name"],
"preview_id": None,
"user_metadata": user_metadata,
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,
}
)
asset_id_to_ref_data[asset_id] = {
"reference_id": reference_id,
"tags": spec["tags"],
"filename": spec["fname"],
"extracted_metadata": extracted_metadata,
}
bulk_insert_assets(session, asset_rows)
# Filter reference rows to only those whose assets were actually inserted
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
inserted_asset_ids = get_existing_asset_ids(
session, [r["asset_id"] for r in reference_rows]
)
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
bulk_insert_references_ignore_conflicts(session, reference_rows)
restore_references_by_paths(session, absolute_path_list)
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
inserted_paths = {
path
for path in absolute_path_list
if path_to_asset_id[path] in inserted_asset_ids
}
losing_paths = inserted_paths - winning_paths
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
if lost_asset_ids:
delete_assets_by_ids(session, lost_asset_ids)
if not winning_paths:
return BulkInsertResult(
inserted_refs=0,
won_paths=0,
lost_paths=len(losing_paths),
)
# Get reference IDs for winners
winning_ref_ids = [
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
for path in winning_paths
]
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
tag_rows: list[TagRow] = []
metadata_rows: list[MetadataRow] = []
if inserted_ref_ids:
for path in winning_paths:
asset_id = path_to_asset_id[path]
ref_data = asset_id_to_ref_data[asset_id]
ref_id = ref_data["reference_id"]
if ref_id not in inserted_ref_ids:
continue
for tag in ref_data["tags"]:
tag_rows.append(
{
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
}
)
# Use extracted metadata for meta rows if available
extracted_metadata = ref_data.get("extracted_metadata")
if extracted_metadata:
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
elif ref_data["filename"]:
# Fallback: just store filename
metadata_rows.append(
{
"asset_reference_id": ref_id,
"key": "filename",
"ordinal": 0,
"val_str": ref_data["filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(
inserted_refs=len(inserted_ref_ids),
won_paths=len(winning_paths),
lost_paths=len(losing_paths),
)
def cleanup_unreferenced_assets(session: Session) -> int:
"""Hard-delete unhashed assets with no active references.
This is a destructive operation intended for explicit cleanup.
Only deletes assets where hash=None and all references are missing.
Returns:
Number of assets deleted
"""
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
return delete_assets_by_ids(session, unreferenced_ids)

View File

@ -0,0 +1,70 @@
import os
def get_mtime_ns(stat_result: os.stat_result) -> int:
"""Extract mtime in nanoseconds from a stat result."""
return getattr(
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
)
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
"""Get file size in bytes and mtime in nanoseconds."""
st = os.stat(path, follow_symlinks=follow_symlinks)
return st.st_size, get_mtime_ns(st)
def verify_file_unchanged(
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
"""Check if a file is unchanged based on mtime and size.
Returns True if the file's mtime and size match the database values.
Returns False if mtime_db is None or values don't match.
size_db=None means don't check size; 0 is a valid recorded size.
"""
if mtime_db is None:
return False
actual_mtime_ns = get_mtime_ns(stat_result)
if int(mtime_db) != int(actual_mtime_ns):
return False
if size_db is not None:
return int(stat_result.st_size) == int(size_db)
return True
def is_visible(name: str) -> bool:
"""Return True if a file or directory name is visible (not hidden)."""
return not name.startswith(".")
def list_files_recursively(base_dir: str) -> list[str]:
"""Recursively list all files in a directory, following symlinks."""
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
# Track seen real directory identities to prevent circular symlink loops
seen_dirs: set[tuple[int, int]] = set()
for dirpath, subdirs, filenames in os.walk(
base_abs, topdown=True, followlinks=True
):
try:
st = os.stat(dirpath)
dir_id = (st.st_dev, st.st_ino)
except OSError:
subdirs.clear()
continue
if dir_id in seen_dirs:
subdirs.clear()
continue
seen_dirs.add(dir_id)
subdirs[:] = [d for d in subdirs if is_visible(d)]
for name in filenames:
if not is_visible(name):
continue
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out

View File

@ -0,0 +1,99 @@
import io
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterator
import logging
try:
from blake3 import blake3
except ModuleNotFoundError:
logging.warning("WARNING: blake3 package not installed")
DEFAULT_CHUNK = 8 * 1024 * 1024
InterruptCheck = Callable[[], bool]
@dataclass
class HashCheckpoint:
"""Saved state for resuming an interrupted hash computation."""
bytes_processed: int
hasher: Any # blake3 hasher instance
mtime_ns: int = 0
file_size: int = 0
@contextmanager
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
"""Yield (file_object, is_path) with appropriate setup/teardown."""
if hasattr(fp, "read"):
seekable = getattr(fp, "seekable", lambda: False)()
orig_pos = None
if seekable:
try:
orig_pos = fp.tell()
if orig_pos != 0:
fp.seek(0)
except io.UnsupportedOperation:
orig_pos = None
try:
yield fp, False
finally:
if orig_pos is not None:
fp.seek(orig_pos)
else:
with open(os.fspath(fp), "rb") as f:
yield f, True
def compute_blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
interrupt_check: InterruptCheck | None = None,
checkpoint: HashCheckpoint | None = None,
) -> tuple[str | None, HashCheckpoint | None]:
"""Compute BLAKE3 hash of a file, with optional checkpoint support.
Args:
fp: File path or file-like object
chunk_size: Size of chunks to read at a time
interrupt_check: Optional callable that returns True if the operation
should be interrupted (e.g. paused or cancelled). Must be
non-blocking so file handles are released immediately. Checked
between chunk reads.
checkpoint: Optional checkpoint to resume from (file paths only)
Returns:
Tuple of (hex_digest, None) on completion, or
(None, checkpoint) on interruption (file paths only), or
(None, None) on interruption of a file object
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
with _open_for_hashing(fp) as (f, is_path):
if checkpoint is not None and is_path:
f.seek(checkpoint.bytes_processed)
h = checkpoint.hasher
bytes_processed = checkpoint.bytes_processed
else:
h = blake3()
bytes_processed = 0
while True:
if interrupt_check is not None and interrupt_check():
if is_path:
return None, HashCheckpoint(
bytes_processed=bytes_processed,
hasher=h,
)
return None, None
chunk = f.read(chunk_size)
if not chunk:
break
h.update(chunk)
bytes_processed += len(chunk)
return h.hexdigest(), None

View File

@ -0,0 +1,463 @@
import contextlib
import logging
import mimetypes
import os
from typing import Any, Sequence
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.queries import (
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
reference_exists,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_tags,
update_asset_hash_and_mime,
upsert_asset,
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags,
validate_path_within_base,
)
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UploadResult,
UserMetadata,
extract_asset_data,
extract_reference_data,
)
from app.database.db import create_session
def _ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
user_metadata = user_metadata or {}
asset_created = False
asset_updated = False
ref_created = False
ref_updated = False
reference_id: str | None = None
with create_session() as session:
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
ref_created, ref_updated = upsert_reference(
session,
asset_id=asset.id,
file_path=locator,
name=info_name or os.path.basename(locator),
mtime_ns=mtime_ns,
owner_id=owner_id,
)
# Get the reference we just created/updated
ref = get_reference_by_file_path(session, locator)
if ref:
reference_id = ref.id
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
norm = normalize_tags(list(tags))
if norm:
if require_existing_tags:
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
_update_metadata_with_filename(
session,
reference_id=reference_id,
file_path=ref.file_path,
current_metadata=ref.user_metadata,
user_metadata=user_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
session.commit()
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
ref_created=ref_created,
ref_updated=ref_updated,
reference_id=reference_id,
)
def _register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
if mime_type and not asset.mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
ref, ref_created = get_or_create_reference(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
preview_id=preview_id,
)
if not ref_created:
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
)
session.commit()
return result
new_meta = dict(user_metadata)
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_reference_metadata(
session,
reference_id=ref.id,
user_metadata=new_meta,
)
if tags is not None:
set_reference_tags(
session,
reference_id=ref.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_reference_tags(session, reference_id=ref.id)
session.refresh(ref)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
)
session.commit()
return result
def _update_metadata_with_filename(
session: Session,
reference_id: str,
file_path: str | None,
current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None:
computed_filename = compute_relative_filename(file_path) if file_path else None
current_meta = current_metadata or {}
new_meta = dict(current_meta)
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_reference_metadata(
session,
reference_id=reference_id,
user_metadata=new_meta,
)
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback
class HashMismatchError(Exception):
pass
class DependencyMissingError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
def upload_from_temp_path(
temp_path: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_hash and asset_hash != expected_hash.strip().lower():
raise HashMismatchError("Uploaded file hash does not match provided hash.")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = _register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
)
return UploadResult(
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
)
if not tags:
raise ValueError("tags are required for new asset uploads")
base_dir, subdirs = resolve_destination_from_tags(tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = _ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=preview_id,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def register_file_in_place(
abs_path: str,
name: str,
tags: list[str],
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash file: {e}")
asset_hash = "blake3:" + digest
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
content_type = mime_type or (
mimetypes.guess_type(abs_path, strict=False)[0]
or "application/octet-stream"
)
ingest_result = _ingest_file_from_path(
abs_path=abs_path,
asset_hash=asset_hash,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name, fallback=digest),
owner_id=owner_id,
tags=merged_tags,
tag_origin="upload",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult | None:
canonical = hash_str.strip().lower()
try:
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
)
except ValueError:
logging.warning("create_from_hash: no asset found for hash %s", canonical)
return None
return UploadResult(
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
)

View File

@ -0,0 +1,327 @@
"""Metadata extraction for asset scanning.
Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
from __future__ import annotations
import json
import logging
import mimetypes
import os
import struct
from dataclasses import dataclass
from typing import Any
from utils.mime_types import init_mime_types
init_mime_types()
# Supported safetensors extensions
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
# Maximum safetensors header size to read (8MB)
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
@dataclass
class ExtractedMetadata:
"""Metadata extracted from a file during scanning."""
# Tier 1: Filesystem (always available)
filename: str = ""
file_path: str = "" # Full absolute path to the file
content_length: int = 0
content_type: str | None = None
format: str = "" # file extension without dot
# Tier 2: Safetensors header (if available)
base_model: str | None = None
trained_words: list[str] | None = None
air: str | None = None # CivitAI AIR identifier
has_preview_images: bool = False
# Source provenance (populated if embedded in safetensors)
source_url: str | None = None
source_arn: str | None = None
repo_url: str | None = None
preview_url: str | None = None
source_hash: str | None = None
# HuggingFace specific
repo_id: str | None = None
revision: str | None = None
filepath: str | None = None
resolve_url: str | None = None
def to_user_metadata(self) -> dict[str, Any]:
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
data: dict[str, Any] = {
"filename": self.filename,
"content_length": self.content_length,
"format": self.format,
}
if self.file_path:
data["file_path"] = self.file_path
if self.content_type:
data["content_type"] = self.content_type
# Tier 2 fields
if self.base_model:
data["base_model"] = self.base_model
if self.trained_words:
data["trained_words"] = self.trained_words
if self.air:
data["air"] = self.air
if self.has_preview_images:
data["has_preview_images"] = True
# Source provenance
if self.source_url:
data["source_url"] = self.source_url
if self.source_arn:
data["source_arn"] = self.source_arn
if self.repo_url:
data["repo_url"] = self.repo_url
if self.preview_url:
data["preview_url"] = self.preview_url
if self.source_hash:
data["source_hash"] = self.source_hash
# HuggingFace
if self.repo_id:
data["repo_id"] = self.repo_id
if self.revision:
data["revision"] = self.revision
if self.filepath:
data["filepath"] = self.filepath
if self.resolve_url:
data["resolve_url"] = self.resolve_url
return data
def to_meta_rows(self, reference_id: str) -> list[dict]:
"""Convert to asset_reference_meta rows for typed/indexed querying."""
rows: list[dict] = []
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
if val:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": ordinal,
"val_str": val[:2048] if len(val) > 2048 else val,
"val_num": None,
"val_bool": None,
"val_json": None,
})
def add_num(key: str, val: int | float | None) -> None:
if val is not None:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": val,
"val_bool": None,
"val_json": None,
})
def add_bool(key: str, val: bool | None) -> None:
if val is not None:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": None,
"val_bool": val,
"val_json": None,
})
# Tier 1
add_str("filename", self.filename)
add_num("content_length", self.content_length)
add_str("content_type", self.content_type)
add_str("format", self.format)
# Tier 2
add_str("base_model", self.base_model)
add_str("air", self.air)
has_previews = self.has_preview_images if self.has_preview_images else None
add_bool("has_preview_images", has_previews)
# trained_words as multiple rows with ordinals
if self.trained_words:
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
add_str("trained_words", word, ordinal=i)
# Source provenance
add_str("source_url", self.source_url)
add_str("source_arn", self.source_arn)
add_str("repo_url", self.repo_url)
add_str("preview_url", self.preview_url)
add_str("source_hash", self.source_hash)
# HuggingFace
add_str("repo_id", self.repo_id)
add_str("revision", self.revision)
add_str("filepath", self.filepath)
add_str("resolve_url", self.resolve_url)
return rows
def _read_safetensors_header(
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
) -> dict[str, Any] | None:
"""Read only the JSON header from a safetensors file.
This is very fast - reads 8 bytes for header length, then the JSON header.
No tensor data is loaded.
Args:
path: Absolute path to safetensors file
max_size: Maximum header size to read (default 8MB)
Returns:
Parsed header dict or None if failed
"""
try:
with open(path, "rb") as f:
header_bytes = f.read(8)
if len(header_bytes) < 8:
return None
length_of_header = struct.unpack("<Q", header_bytes)[0]
if length_of_header > max_size:
return None
header_data = f.read(length_of_header)
if len(header_data) < length_of_header:
return None
return json.loads(header_data.decode("utf-8"))
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
return None
def _extract_safetensors_metadata(
header: dict[str, Any], meta: ExtractedMetadata
) -> None:
"""Extract metadata from safetensors header __metadata__ section.
Modifies meta in-place.
"""
st_meta = header.get("__metadata__", {})
if not isinstance(st_meta, dict):
return
# Common model metadata
meta.base_model = (
st_meta.get("ss_base_model_version")
or st_meta.get("modelspec.base_model")
or st_meta.get("base_model")
)
# Trained words / trigger words
trained_words = st_meta.get("ss_tag_frequency")
if trained_words and isinstance(trained_words, str):
try:
tag_freq = json.loads(trained_words)
# Extract unique tags from all datasets
all_tags: set[str] = set()
for dataset_tags in tag_freq.values():
if isinstance(dataset_tags, dict):
all_tags.update(dataset_tags.keys())
if all_tags:
meta.trained_words = sorted(all_tags)[:100]
except json.JSONDecodeError:
pass
# Direct trained_words field (some formats)
if not meta.trained_words:
tw = st_meta.get("trained_words")
if isinstance(tw, str):
try:
parsed = json.loads(tw)
if isinstance(parsed, list):
meta.trained_words = [str(x) for x in parsed]
else:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
except json.JSONDecodeError:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
elif isinstance(tw, list):
meta.trained_words = [str(x) for x in tw]
# CivitAI AIR
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
# Preview images (ssmd_cover_images)
cover_images = st_meta.get("ssmd_cover_images")
if cover_images:
meta.has_preview_images = True
# Source provenance fields
meta.source_url = st_meta.get("source_url")
meta.source_arn = st_meta.get("source_arn")
meta.repo_url = st_meta.get("repo_url")
meta.preview_url = st_meta.get("preview_url")
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
# HuggingFace fields
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
def extract_file_metadata(
abs_path: str,
stat_result: os.stat_result | None = None,
relative_filename: str | None = None,
) -> ExtractedMetadata:
"""Extract metadata from a file using tier 1 and tier 2 methods.
Tier 1: Filesystem metadata from path and stat
Tier 2: Safetensors header parsing if applicable
Args:
abs_path: Absolute path to the file
stat_result: Optional pre-fetched stat result (saves a syscall)
relative_filename: Optional relative filename to use instead of basename
(e.g., "flux/123/model.safetensors" for model paths)
Returns:
ExtractedMetadata with all available fields populated
"""
meta = ExtractedMetadata()
# Tier 1: Filesystem metadata
meta.filename = relative_filename or os.path.basename(abs_path)
meta.file_path = abs_path
_, ext = os.path.splitext(abs_path)
meta.format = ext.lstrip(".").lower() if ext else ""
mime_type, _ = mimetypes.guess_type(abs_path)
meta.content_type = mime_type
# Size from stat
if stat_result is None:
try:
stat_result = os.stat(abs_path, follow_symlinks=True)
except OSError:
pass
if stat_result:
meta.content_length = stat_result.st_size
# Tier 2: Safetensors header (if applicable and enabled)
if ext.lower() in SAFETENSORS_EXTENSIONS:
header = _read_safetensors_header(abs_path)
if header:
try:
_extract_safetensors_metadata(header, meta)
except Exception as e:
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
return meta

View File

@ -0,0 +1,167 @@
import os
from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build list of (folder_name, base_paths[]) for all model locations.
Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir,
but excludes non-model entries like custom_nodes.
"""
targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items():
if name in _NON_MODEL_FOLDER_NAMES:
continue
paths, _exts = values[0], values[1]
if paths:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
cand_abs = Path(os.path.abspath(candidate))
base_abs = Path(os.path.abspath(base))
if not cand_abs.is_relative_to(base_abs):
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
(root_category, relative_path_inside_that_root)
Raises:
ValueError: path does not belong to any known root.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return "input", _compute_relative(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
)
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] + parent folder names in order
Raises:
ValueError: path does not belong to any known root.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@ -0,0 +1,113 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetReference
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
hash: str | None
size_bytes: int | None
mime_type: str | None
@dataclass(frozen=True)
class ReferenceData:
"""Data transfer object for AssetReference."""
id: str
name: str
file_path: str | None
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
system_metadata: dict[str, Any] | None = None
job_id: str | None = None
last_access_time: datetime | None = None
@dataclass(frozen=True)
class AssetDetailResult:
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
ref: ReferenceData
asset: AssetData
tags: list[str]
created: bool
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
ref_created: bool
ref_updated: bool
reference_id: str | None
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@dataclass(frozen=True)
class AssetSummaryData:
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
@dataclass(frozen=True)
class DownloadResolutionResult:
abs_path: str
content_type: str
download_name: str
@dataclass(frozen=True)
class UploadResult:
ref: ReferenceData
asset: AssetData
tags: list[str]
created_new: bool
def extract_reference_data(ref: AssetReference) -> ReferenceData:
return ReferenceData(
id=ref.id,
name=ref.name,
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
job_id=ref.job_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,
)
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(
hash=asset.hash,
size_bytes=asset.size_bytes,
mime_type=asset.mime_type,
)

View File

@ -0,0 +1,98 @@
from typing import Sequence
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
add_tags_to_reference,
get_reference_with_owner_check,
list_tags_with_usage,
remove_tags_from_reference,
)
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
def apply_tags(
reference_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
result = add_tags_to_reference(
session,
reference_id=reference_id,
tags=tags,
origin=origin,
create_if_missing=True,
reference_row=ref_row,
)
session.commit()
return result
def remove_tags(
reference_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
get_reference_with_owner_check(session, reference_id, owner_id)
result = remove_tags_from_reference(
session,
reference_id=reference_id,
tags=tags,
)
session.commit()
return result
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@ -3,6 +3,7 @@ import os
import shutil import shutil
from app.logger import log_startup_warning from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message from utils.install_util import get_missing_requirements_message
from filelock import FileLock, Timeout
from comfy.cli_args import args from comfy.cli_args import args
_DB_AVAILABLE = False _DB_AVAILABLE = False
@ -14,8 +15,12 @@ try:
from alembic.config import Config from alembic.config import Config
from alembic.runtime.migration import MigrationContext from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory from alembic.script import ScriptDirectory
from sqlalchemy import create_engine from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database.models import Base
import app.assets.database.models # noqa: F401 — register models with Base.metadata
_DB_AVAILABLE = True _DB_AVAILABLE = True
except ImportError as e: except ImportError as e:
@ -65,9 +70,69 @@ def get_db_path():
raise ValueError(f"Unsupported database URL '{url}'.") raise ValueError(f"Unsupported database URL '{url}'.")
_db_lock = None
def _acquire_file_lock(db_path):
"""Acquire an OS-level file lock to prevent multi-process access.
Uses filelock for cross-platform support (macOS, Linux, Windows).
The OS automatically releases the lock when the process exits, even on crashes.
"""
global _db_lock
lock_path = db_path + ".lock"
_db_lock = FileLock(lock_path)
try:
_db_lock.acquire(timeout=0)
except Timeout:
raise RuntimeError(
f"Could not acquire lock on database '{db_path}'. "
"Another ComfyUI process may already be using it. "
"Use --database-url to specify a separate database file."
)
def _is_memory_db(db_url):
"""Check if the database URL refers to an in-memory SQLite database."""
return db_url in ("sqlite:///:memory:", "sqlite://")
def init_db(): def init_db():
db_url = args.database_url db_url = args.database_url
logging.debug(f"Database URL: {db_url}") logging.debug(f"Database URL: {db_url}")
if _is_memory_db(db_url):
_init_memory_db(db_url)
else:
_init_file_db(db_url)
def _init_memory_db(db_url):
"""Initialize an in-memory SQLite database using metadata.create_all.
Alembic migrations don't work with in-memory SQLite because each
connection gets its own separate database tables created by Alembic's
internal connection are lost immediately.
"""
engine = create_engine(
db_url,
poolclass=StaticPool,
connect_args={"check_same_thread": False},
)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine)
global Session
Session = sessionmaker(bind=engine)
def _init_file_db(db_url):
"""Initialize a file-backed SQLite database using Alembic migrations."""
db_path = get_db_path() db_path = get_db_path()
db_exists = os.path.exists(db_path) db_exists = os.path.exists(db_path)
@ -75,6 +140,14 @@ def init_db():
# Check if we need to upgrade # Check if we need to upgrade
engine = create_engine(db_url) engine = create_engine(db_url)
# Enable foreign key enforcement for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
conn = engine.connect() conn = engine.connect()
context = MigrationContext.configure(conn) context = MigrationContext.configure(conn)
@ -104,6 +177,12 @@ def init_db():
logging.exception("Error upgrading database: ") logging.exception("Error upgrading database: ")
raise e raise e
# Acquire an OS-level file lock after migrations are complete.
# Alembic uses its own connection, so we must wait until it's done
# before locking — otherwise our own lock blocks the migration.
conn.close()
_acquire_file_lock(db_path)
global Session global Session
Session = sessionmaker(bind=engine) Session = sessionmaker(bind=engine)

View File

@ -1,9 +1,18 @@
from typing import Any from typing import Any
from datetime import datetime from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass metadata = MetaData(naming_convention=NAMING_CONVENTION)
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys() fields = obj.__table__.columns.keys()

View File

@ -6,6 +6,7 @@ import uuid
import glob import glob
import shutil import shutil
import logging import logging
import tempfile
from aiohttp import web from aiohttp import web
from urllib import parse from urllib import parse
from comfy.cli_args import args from comfy.cli_args import args
@ -377,8 +378,15 @@ class UserManager():
try: try:
body = await request.read() body = await request.read()
with open(path, "wb") as f: dir_name = os.path.dirname(path)
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body) f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
except OSError as e: except OSError as e:
logging.warning(f"Error saving file '{path}': {e}") logging.warning(f"Error saving file '{path}': {e}")
return web.Response( return web.Response(

View File

@ -27,6 +27,7 @@ class AudioEncoderModel():
self.model.eval() self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device) self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000 self.model_sample_rate = 16000
comfy.model_management.archive_model_dtypes(self.model)
def load_sd(self, sd): def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic()) return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

View File

@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
@ -147,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@ -232,7 +235,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db") os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
) )
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.") parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.") parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
if comfy.options.args_parsing: if comfy.options.args_parsing:
args = parser.parse_args() args = parser.parse_args()
@ -260,4 +263,6 @@ else:
args.fast = set(args.fast) args.fast = set(args.fast)
def enables_dynamic_vram(): def enables_dynamic_vram():
if args.enable_dynamic_vram:
return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget. """COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4 Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]] gradient_stops: NotRequired[list[dict]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``).""" """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
class HiddenInputTypeDict(TypedDict): class HiddenInputTypeDict(TypedDict):

View File

@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block) output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False) return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult return tensor * m_mult
else: else:
for d in modulation_dims: for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
if m_add is not None: if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]] tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
return tensor return tensor
@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
del txt_k, img_k del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2) v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v del txt_v, img_v
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# run actual attention # run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v
if "attn1_output_patch" in transformer_patches: if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"] patch = transformer_patches["attn1_output_patch"]
for p in patch: for p in patch:
attn = p(attn, extra_options) attn = p(attn, extra_options)
@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
del qkv del qkv
q, k = self.norm(q, k, v) q, k = self.norm(q, k, v)
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# compute attention # compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options) attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v del q, k, v

View File

@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def _apply_rope1(x: Tensor, freqs_cis: Tensor): def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2) x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
x_out = freqs_cis[..., 0] * x_[..., 0] x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1]) x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

View File

@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False txt_norm: bool = False
def invert_slices(slices, length):
sorted_slices = sorted(slices)
result = []
current = 0
for start, end in sorted_slices:
if current < start:
result.append((current, start))
current = max(current, end)
if current < length:
result.append((current, length))
return result
class Flux(nn.Module): class Flux(nn.Module):
""" """
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
@ -138,6 +154,7 @@ class Flux(nn.Module):
y: Tensor, y: Tensor,
guidance: Tensor = None, guidance: Tensor = None,
control = None, control = None,
timestep_zero_index=None,
transformer_options={}, transformer_options={},
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
@ -164,13 +181,9 @@ class Flux(nn.Module):
txt = self.txt_norm(txt) txt = self.txt_norm(txt)
txt = self.txt_in(txt) txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches: if "post_input" in patches:
for p in patches["post_input"]: for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids}) out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
img = out["img"] img = out["img"]
txt = out["txt"] txt = out["txt"]
img_ids = out["img_ids"] img_ids = out["img_ids"]
@ -182,6 +195,24 @@ class Flux(nn.Module):
else: else:
pe = None pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks) transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
@ -195,7 +226,8 @@ class Flux(nn.Module):
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask"), attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options")) transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out return out
out = blocks_replace[("double_block", i)]({"img": img, out = blocks_replace[("double_block", i)]({"img": img,
@ -213,7 +245,8 @@ class Flux(nn.Module):
vec=vec, vec=vec,
pe=pe, pe=pe,
attn_mask=attn_mask, attn_mask=attn_mask,
transformer_options=transformer_options) transformer_options=transformer_options,
**extra_kwargs)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@ -230,6 +263,12 @@ class Flux(nn.Module):
if self.params.global_modulation: if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig) vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@ -242,7 +281,8 @@ class Flux(nn.Module):
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask"), attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options")) transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out return out
out = blocks_replace[("single_block", i)]({"img": img, out = blocks_replace[("single_block", i)]({"img": img,
@ -253,7 +293,7 @@ class Flux(nn.Module):
{"original_block": block_wrap}) {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")
@ -264,7 +304,11 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@ -312,13 +356,16 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options) img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1] img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
ref_num_tokens = []
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents: for ref in ref_latents:
if ref_latents_method == "index": if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale index += self.params.ref_index_scale
h_offset = 0 h_offset = 0
w_offset = 0 w_offset = 0
@ -342,6 +389,13 @@ class Flux(nn.Module):
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
img = torch.cat([img, kontext], dim=1) img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@ -349,6 +403,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims: for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens] out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim), k.reshape(b, s2, self.num_heads * self.head_dim),
v, v,
heads=self.num_heads, heads=self.num_heads,
low_precision_attention=False,
) )
out = self.out_proj(x) out = self.out_proj(x)
@ -412,6 +413,7 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim), key.reshape(B, N, self.num_heads * self.head_dim),
value, value,
heads=self.num_heads, heads=self.num_heads,
low_precision_attention=False,
) )
x = self.out_proj(x) x = self.out_proj(x)

View File

@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -536,7 +537,7 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out) mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal) sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0: if sample is not None and sample.shape[2] > 0:
output.append(sample) output.append(sample.to(comfy.model_management.intermediate_device()))
return return
up_block = self.up_blocks[idx] up_block = self.up_blocks[idx]

View File

@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v) r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2 del s2
break break
except model_management.OOM_EXCEPTION as e: except Exception as e:
model_management.raise_non_oom(e)
if first_op_done == False: if first_op_done == False:
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
if cleared_cache == False: if cleared_cache == False:

View File

@ -258,7 +258,8 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2) r1[:, :, i:end] = torch.bmm(v, s2)
del s2 del s2
break break
except model_management.OOM_EXCEPTION as e: except Exception as e:
model_management.raise_non_oom(e)
model_management.soft_empty_cache(True) model_management.soft_empty_cache(True)
steps *= 2 steps *= 2
if steps > 128: if steps > 128:
@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
try: try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False) out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape) out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION: except Exception as e:
model_management.raise_non_oom(e)
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention") logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True oom_fallback = True
if oom_fallback: if oom_fallback:

View File

@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
try: try:
attn_probs = attn_scores.softmax(dim=-1) attn_probs = attn_scores.softmax(dim=-1)
del attn_scores del attn_scores
except model_management.OOM_EXCEPTION: except Exception as e:
model_management.raise_non_oom(e)
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead") logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores) torch.exp(attn_scores, out=attn_scores)

View File

@ -149,6 +149,9 @@ class Attention(nn.Module):
seq_img = hidden_states.shape[1] seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1] seq_txt = encoder_hidden_states.shape[1]
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# Project and reshape to BHND format (batch, heads, seq, dim) # Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous() img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
@ -167,15 +170,22 @@ class Attention(nn.Module):
joint_key = torch.cat([txt_key, img_key], dim=2) joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2) joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None: if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device) attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else: else:
attn_mask = None attn_mask = None
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads, joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options, attn_mask, transformer_options=transformer_options,
skip_reshape=True) skip_reshape=True)
@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
timestep_zero_index = None timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
ref_num_tokens = []
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1) hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero: if timestep_zero:
if index > 0: if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0) timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds timestep_zero_index = num_embeds
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2)) txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3) txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states) hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states) encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {}) patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
transformer_options["total_blocks"] = len(self.transformer_blocks) transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks): for i, block in enumerate(self.transformer_blocks):

View File

@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk: for k in sdk:
if k.endswith(".weight"): if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}" text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False clip_l_present = False

View File

@ -1,9 +1,68 @@
import math import math
import ctypes
import threading
import dataclasses
import torch import torch
from typing import NamedTuple from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int
def read_tensor_file_slice_into(tensor, destination):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
return False
if info.size == 0:
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()
class TensorGeometry(NamedTuple): class TensorGeometry(NamedTuple):
shape: any shape: any
dtype: torch.dtype dtype: torch.dtype

View File

@ -1,4 +1,5 @@
import json import json
import comfy.memory_management
import comfy.supported_models import comfy.supported_models
import comfy.supported_models_base import comfy.supported_models_base
import comfy.utils import comfy.utils
@ -1118,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
new[:old_weight.shape[0]] = old_weight new[:old_weight.shape[0]] = old_weight
old_weight = new old_weight = new
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
old_weight = old_weight.clone()
w = old_weight.narrow(offset[0], offset[1], offset[2]) w = old_weight.narrow(offset[0], offset[1], offset[2])
else: else:
if comfy.memory_management.aimdo_enabled:
weight = weight.clone()
old_weight = weight old_weight = weight
w = weight w = weight
w[:] = fun(weight) w[:] = fun(weight)

View File

@ -270,6 +270,23 @@ try:
except: except:
OOM_EXCEPTION = Exception OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
def raise_non_oom(e):
if not is_oom(e):
raise e
XFORMERS_VERSION = "" XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True XFORMERS_ENABLED_VAE = True
if args.disable_xformers: if args.disable_xformers:
@ -383,7 +400,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0): if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]): if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@ -488,6 +505,28 @@ def module_size(module):
module_mem += t.nbytes module_mem += t.nbytes
return module_mem return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self._set_model(model) self._set_model(model)
@ -502,6 +541,7 @@ class LoadedModel:
if model.parent is not None: if model.parent is not None:
self._parent_model = weakref.ref(model.parent) self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent) self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
self._patcher_finalizer.atexit = False
def _switch_parent(self): def _switch_parent(self):
model = self._parent_model() model = self._parent_model()
@ -515,6 +555,9 @@ class LoadedModel:
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self): def model_loaded_memory(self):
return self.model.loaded_size() return self.model.loaded_size()
@ -545,6 +588,7 @@ class LoadedModel:
self.real_model = weakref.ref(real_model) self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models) self.model_finalizer = weakref.finalize(real_model, cleanup_models)
self.model_finalizer.atexit = False
return real_model return real_model
def should_reload_model(self, force_patch_weights=False): def should_reload_model(self, force_patch_weights=False):
@ -616,7 +660,7 @@ def extra_reserved_memory():
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
@ -629,13 +673,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
for x in sorted(can_unload): can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
ram_to_free = 1e32 pins_to_free = 1e32
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device) memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram() pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic: if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand. #as that works on-demand.
@ -644,9 +689,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i) unloaded_model.append(i)
if ram_to_free > 0: if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
@ -712,17 +766,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {} total_memory_required = {}
total_pins_required = {}
total_ram_required = {} total_ram_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) device = loaded_model.device
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
#want to do. resident_memory, model_memory = loaded_model.model_mmap_residency()
#FIXME: This should subtract off the to_load current pin consumption. pinned_memory = loaded_model.model.pinned_memory_size()
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
@ -939,7 +1003,7 @@ def text_encoder_offload_device():
def text_encoder_device(): def text_encoder_device():
if args.gpu_only: if args.gpu_only:
return get_torch_device() return get_torch_device()
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED): elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM, VRAMState.SHARED) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False): if should_use_fp16(prioritize_performance=False):
return get_torch_device() return get_torch_device()
else: else:
@ -988,6 +1052,12 @@ def intermediate_device():
else: else:
return torch.device("cpu") return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device(): def vae_device():
if args.cpu_vae: if args.cpu_vae:
return torch.device("cpu") return torch.device("cpu")
@ -1148,6 +1218,7 @@ def reset_cast_buffers():
LARGEST_CASTED_WEIGHT = (None, 0) LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS: for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize() offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear() STREAM_CAST_BUFFERS.clear()
soft_empty_cache() soft_empty_cache()
@ -1207,6 +1278,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0) dest_view = dest_views.pop(0)
if tensor is None: if tensor is None:
continue continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking) dest_view.copy_(tensor, non_blocking=non_blocking)
@ -1262,7 +1338,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b _ = a + b
synchronize() synchronize()
except torch.AcceleratorError: except RuntimeError:
#Dump it! We already know about it from the synchronous return #Dump it! We already know about it from the synchronous return
pass pass
@ -1644,6 +1720,19 @@ def supports_nvfp4_compute(device=None):
return True return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support(): def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older # TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7): if torch_version_numeric < (2, 7):

View File

@ -297,6 +297,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self): def get_ram_usage(self):
return self.model_size() return self.model_size()
@ -599,6 +602,27 @@ class ModelPatcher:
return models return models
def model_patches_call_function(self, function_name="cleanup", arguments={}):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], function_name):
getattr(patch_list[i], function_name)(**arguments)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], function_name):
getattr(patch_list[k], function_name)(**arguments)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, function_name):
getattr(wrap_func, function_name)(**arguments)
def model_dtype(self): def model_dtype(self):
if hasattr(self.model, "get_dtype"): if hasattr(self.model, "get_dtype"):
return self.model.get_dtype() return self.model.get_dtype()
@ -715,8 +739,8 @@ class ModelPatcher:
default = True # default random weights in non leaf modules default = True # default random weights in non leaf modules
break break
if default and default_device is not None: if default and default_device is not None:
for param in params.values(): for param_name, param in params.items():
param.data = param.data.to(device=default_device) param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0): if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m) module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem module_offload_mem = module_mem
@ -1042,6 +1066,10 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used return self.model.model_loaded_weight_memory - current_used
def pinned_memory_size(self):
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
pass pass
@ -1062,6 +1090,7 @@ class ModelPatcher:
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype) return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self): def cleanup(self):
self.model_patches_call_function(function_name="cleanup")
self.clean_hooks() self.clean_hooks()
if hasattr(self.model, "current_patcher"): if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None self.model.current_patcher = None
@ -1631,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device) loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading: for x in loading:

View File

@ -306,10 +306,40 @@ class CastWeightBiasOp:
bias_function = [] bias_function = []
class disable_weight_init: class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp): class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: # don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
return return
@ -330,32 +360,21 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) disable_weight_init._lazy_load_from_state_dict(
prefix_len = len(prefix) self,
for k,v in state_dict.items(): state_dict,
if k[prefix_len:] == "weight": prefix,
if not assign_to_params_buffers: local_metadata,
v = v.clone() missing_keys,
self.weight = torch.nn.Parameter(v, requires_grad=False) unexpected_keys,
elif k[prefix_len:] == "bias" and v is not None: weight_shape=(self.in_features, self.out_features),
if not assign_to_params_buffers: bias_shape=(self.out_features,),
v = v.clone() )
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self): def reset_parameters(self):
@ -547,6 +566,53 @@ class disable_weight_init:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp): class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return
torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Keep shape/dtype visible for module introspection without reserving storage.
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
self.weight = torch.nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
requires_grad=False,
)
self.bias = None
self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)
def reset_parameters(self): def reset_parameters(self):
self.bias = None self.bias = None
return None return None
@ -710,6 +776,71 @@ from .quant_ops import (
) )
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
else:
weight_f = weight.to(compute_dtype)
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast): class MixedPrecisionOps(manual_cast):
_quant_config = quant_config _quant_config = quant_config
@ -801,6 +932,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features), orig_shape=(self.out_features, self.in_features),
) )
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4": elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@ -888,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype #If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and _use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0): len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -939,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items(): for key, param in self._parameters.items():
if param is None: if param is None:
continue continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items(): for key, buf in self._buffers.items():
if buf is not None: if buf is not None:
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
@ -950,12 +1127,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations") logging.info("Using mixed precision operations")
disabled = set() disabled = set()
if not nvfp4_compute: if not nvfp4_compute:
disabled.add("nvfp4") disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute: if not fp8_compute:
disabled.add("float8_e4m3fn") disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2") disabled.add("float8_e5m2")

View File

@ -1,6 +1,7 @@
import torch
import comfy.model_management import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
from comfy.cli_args import args from comfy.cli_args import args
@ -12,18 +13,31 @@ def pin_memory(module):
return return
#FIXME: This is a RAM cache trigger event #FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin): if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module._pin = pin
else:
module.pin_failed = True module.pin_failed = True
return False return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
except RuntimeError:
module.pin_failed = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True return True
def unpin_memory(module): def unpin_memory(module):
if get_pin(module) is None: if get_pin(module) is None:
return 0 return 0
size = module._pin.numel() * module._pin.element_size() size = module._pin.numel() * module._pin.element_size()
comfy.model_management.unpin_memory(module._pin)
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin del module._pin
del module._pin_hostbuf
return size return size

View File

@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name): def get_layout_class(name):
return None return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float import comfy.float
# ============================================================================== # ==============================================================================
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout): class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod @classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = { QUANT_ALGOS = {
"float8_e4m3fn": { "float8_e4m3fn": {
@ -157,6 +196,14 @@ QUANT_ALGOS = {
}, },
} }
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ============================================================================== # ==============================================================================
# Re-exports for backward compatibility # Re-exports for backward compatibility

View File

@ -871,13 +871,16 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
output = self.process_output( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -887,16 +890,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32): def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3: if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
else: else:
og_shape = samples.shape og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -905,7 +908,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -914,7 +917,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1: if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
out_channels = self.latent_channels out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
else: else:
@ -923,7 +926,7 @@ class VAE:
tile_x = tile_x // extra_channel_size tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1: if self.latent_dim == 1:
@ -932,7 +935,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1) return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
@ -950,11 +953,12 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None: if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION: except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the #NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block. #exception and the exception itself refs them all until we get out of this except block.
@ -1024,12 +1028,13 @@ class VAE:
samples = None samples = None
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float() out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None: if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION: except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.") logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the #NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block. #exception and the exception itself refs them all until we get out of this except block.

View File

@ -20,6 +20,8 @@
import torch import torch
import math import math
import struct import struct
import ctypes
import os
import comfy.memory_management import comfy.memory_management
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
@ -32,7 +34,7 @@ from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
import json import json
import time import time
import mmap import threading
import warnings import warnings
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
@ -81,14 +83,17 @@ _TYPES = {
} }
def load_safetensors(ckpt): def load_safetensors(ckpt):
f = open(ckpt, "rb") import comfy_aimdo.model_mmap
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0] f = open(ckpt, "rb", buffering=0)
header = json.loads(mapping[8:8+header_size].decode("utf-8")) model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
mv = mv[8 + header_size:] header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):]
sd = {} sd = {}
for name, info in header.items(): for name, info in header.items():
@ -102,7 +107,14 @@ def load_safetensors(ckpt):
with warnings.catch_warnings(): with warnings.catch_warnings():
#We are working with read-only RAM by design #We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable") warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"]) tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
return sd, header.get("__metadata__", {}), return sd, header.get("__metadata__", {}),
@ -885,6 +897,10 @@ def set_attr(obj, attr, value):
return prev return prev
def set_attr_param(obj, attr, value): def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value): def set_attr_buffer(obj, attr, value):

View File

@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes "max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}}, "extension": {"manager": {"supports_v4": True}},
"node_replacements": True, "node_replacements": True,
"assets": args.enable_assets,
} }

View File

@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__() super().__init__()
self.node_replacement = self.NodeReplacement() self.node_replacement = self.NodeReplacement()
self.execution = self.Execution() self.execution = self.Execution()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton): class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None: async def register(self, node_replace: io.NodeReplace) -> None:
@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display, image=to_display,
) )
class Caching(ProxiedSingleton):
"""
External cache provider API for sharing cached node outputs
across ComfyUI instances.
Example::
from comfy_api.latest import Caching
class MyCacheProvider(Caching.CacheProvider):
async def on_lookup(self, context):
... # check external storage
async def on_store(self, context, value):
... # store to external storage
Caching.register_provider(MyCacheProvider())
"""
from ._caching import CacheProvider, CacheContext, CacheValue
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Register an external cache provider. Providers are called in registration order."""
from comfy_execution.cache_provider import register_cache_provider
register_cache_provider(provider)
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
"""Unregister a previously registered cache provider."""
from comfy_execution.cache_provider import unregister_cache_provider
unregister_cache_provider(provider)
class ComfyExtension(ABC): class ComfyExtension(ABC):
async def on_load(self) -> None: async def on_load(self) -> None:
""" """
@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL VOXEL = VOXEL
File3D = File3D File3D = File3D
Caching = ComfyAPI_latest.Caching
ComfyAPI = ComfyAPI_latest ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API # Create a synchronous version of the API
@ -135,6 +169,7 @@ __all__ = [
"Input", "Input",
"InputImpl", "InputImpl",
"Types", "Types",
"Caching",
"ComfyExtension", "ComfyExtension",
"io", "io",
"IO", "IO",

View File

@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
node_id: str
class_type: str
cache_key_hash: str # SHA256 hex digest
@dataclass
class CacheValue:
outputs: list
ui: dict = None
class CacheProvider(ABC):
"""Abstract base class for external cache providers.
Exceptions from provider methods are caught by the caller and never break execution.
"""
@abstractmethod
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
pass
@abstractmethod
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
"""Called after local store. Dispatched via asyncio.create_task."""
pass
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
"""Return False to skip external caching for this node. Default: True."""
return True
def on_prompt_start(self, prompt_id: str) -> None:
pass
def on_prompt_end(self, prompt_id: str) -> None:
pass

View File

@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False has_first_frame = False
for frame in frames: for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate) to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples: if to_skip < frame.samples:
has_first_frame = True has_first_frame = True
break break
@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:]) audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames: for frame in frames:
if frame.time > start_time + self.__duration: if self.__duration and frame.time > start_time + self.__duration:
break break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0: if len(audio_frames) > 0:

View File

@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None, display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min self.min = min

View File

@ -0,0 +1,68 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@ -1,3 +1,7 @@
import zipfile
from io import BytesIO
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api.latest import IO, ComfyExtension, Input, Types
@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
bytesio_to_image_tensor,
download_url_to_bytesio,
download_url_to_file_3d, download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side, downscale_image_tensor_by_max_side,
poll_op, poll_op,
sync_op, sync_op,
@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
) )
class ObjZipResult:
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
def __init__(
self,
obj: Types.File3D,
texture: Input.Image | None = None,
metallic: Input.Image | None = None,
normal: Input.Image | None = None,
roughness: Input.Image | None = None,
):
self.obj = obj
self.texture = texture
self.metallic = metallic
self.normal = normal
self.roughness = roughness
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
identified by their filename suffixes.
"""
data = BytesIO()
await download_url_to_bytesio(url, data)
data.seek(0)
if not zipfile.is_zipfile(data):
data.seek(0)
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
data.seek(0)
obj_bytes = None
textures: dict[str, Input.Image] = {}
with zipfile.ZipFile(data) as zf:
for name in zf.namelist():
lower = name.lower()
if lower.endswith(".obj"):
obj_bytes = zf.read(name)
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
stem = lower.rsplit(".", 1)[0]
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
matched_key = "texture"
for suffix, key in {
"_metallic": "metallic",
"_normal": "normal",
"_roughness": "roughness",
}.items():
if stem.endswith(suffix):
matched_key = key
break
textures[matched_key] = tensor
if obj_bytes is None:
raise ValueError("ZIP archive does not contain an OBJ file.")
return ObjZipResult(
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
texture=textures.get("texture"),
metallic=textures.get("metallic"),
normal=textures.get("normal"),
roughness=textures.get("roughness"),
)
def get_file_from_response( def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None: ) -> ResultFile3D | None:
@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"), IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
await download_url_to_file_3d( obj_result.obj,
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id obj_result.texture,
),
) )
@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"), IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.Image.Output(display_name="optional_metallic"),
IO.Image.Output(display_name="optional_normal"),
IO.Image.Output(display_name="optional_roughness"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
await download_url_to_file_3d( obj_result.obj,
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id obj_result.texture,
), obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
) )
@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
], ],
outputs=[ outputs=[
IO.File3DGLB.Output(display_name="GLB"), IO.File3DGLB.Output(display_name="GLB"),
IO.File3DFBX.Output(display_name="FBX"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
) )
return IO.NodeOutput( return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
) )
@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
TencentTextToModelNode, TencentTextToModelNode,
TencentImageToModelNode, TencentImageToModelNode,
TencentModelTo3DUVNode, TencentModelTo3DUVNode,
# Tencent3DTextureEditNode, Tencent3DTextureEditNode,
Tencent3DPartNode, Tencent3DPartNode,
TencentSmartTopologyNode, TencentSmartTopologyNode,
] ]

View File

@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
node_id="KlingOmniProEditVideoNode", node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video", display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling", category="api node/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.", description="Edit an existing video with the latest model from Kling.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),

View File

@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
node_id="RecraftVectorizeImageNode", node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image", display_name="Recraft Vectorize Image",
category="api node/image/Recraft", category="api node/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.", description="Generates SVG synchronously from an input image.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),

View File

@ -0,0 +1,395 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 524.48
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=5,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@ -67,6 +67,7 @@ class _RequestConfig:
progress_origin_ts: float | None = None progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass @dataclass
@ -202,11 +203,13 @@ async def sync_op_raw(
monitor_progress: bool = True, monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16, max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None, is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes: ) -> dict[str, Any] | bytes:
""" """
Make a single network request. Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON). - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes. - If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
""" """
if isinstance(data, BaseModel): if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True) data = data.model_dump(exclude_none=True)
@ -232,6 +235,7 @@ async def sync_op_raw(
price_extractor=price_extractor, price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit, max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited, is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
) )
return await _request_base(cfg, expect_binary=as_binary) return await _request_base(cfg, expect_binary=as_binary)
@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
) )
bytes_payload = bytes(buff) bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time) final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response( request_logger.log_request_response(
@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method, request_method=method,
request_url=url, request_url=url,
response_status_code=resp.status, response_status_code=resp.status,
response_headers=dict(resp.headers), response_headers=resp_headers,
response_content=bytes_payload, response_content=bytes_payload,
) )
return bytes_payload return bytes_payload

View File

@ -0,0 +1,138 @@
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
_logger = logging.getLogger(__name__)
_providers: List[CacheProvider] = []
_providers_lock = threading.Lock()
_providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None:
"""Register an external cache provider. Providers are called in registration order."""
global _providers_snapshot
with _providers_lock:
if provider in _providers:
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
return
_providers.append(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
def unregister_cache_provider(provider: CacheProvider) -> None:
global _providers_snapshot
with _providers_lock:
try:
_providers.remove(provider)
_providers_snapshot = tuple(_providers)
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
except ValueError:
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
return _providers_snapshot
def _has_cache_providers() -> bool:
return bool(_providers_snapshot)
def _clear_cache_providers() -> None:
global _providers_snapshot
with _providers_lock:
_providers.clear()
_providers_snapshot = ()
def _canonicalize(obj: Any) -> Any:
# Convert to canonical JSON-serializable form with deterministic ordering.
# Frozensets have non-deterministic iteration order between Python sessions.
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
# _serialize_cache_key returns None and external caching is skipped.
if isinstance(obj, frozenset):
return ("__frozenset__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, set):
return ("__set__", sorted(
[_canonicalize(item) for item in obj],
key=lambda x: json.dumps(x, sort_keys=True)
))
elif isinstance(obj, tuple):
return ("__tuple__", [_canonicalize(item) for item in obj])
elif isinstance(obj, list):
return [_canonicalize(item) for item in obj]
elif isinstance(obj, dict):
return {"__dict__": sorted(
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
key=lambda x: json.dumps(x, sort_keys=True)
)}
elif isinstance(obj, (int, float, str, bool, type(None))):
return (type(obj).__name__, obj)
elif isinstance(obj, bytes):
return ("__bytes__", obj.hex())
else:
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
# Returns deterministic SHA256 hex digest, or None on failure.
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
try:
canonical = _canonicalize(cache_key)
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
except Exception as e:
_logger.warning(f"Failed to serialize cache key: {e}")
return None
def _contains_self_unequal(obj: Any) -> bool:
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
# never hit locally, but serialized form would match externally. Skip these.
try:
if not (obj == obj):
return True
except Exception:
return True
if isinstance(obj, (frozenset, tuple, list, set)):
return any(_contains_self_unequal(item) for item in obj)
if isinstance(obj, dict):
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
if hasattr(obj, 'value'):
return _contains_self_unequal(obj.value)
return False
def _estimate_value_size(value: CacheValue) -> int:
try:
import torch
except ImportError:
return 0
total = 0
def estimate(obj):
nonlocal total
if isinstance(obj, torch.Tensor):
total += obj.numel() * obj.element_size()
elif isinstance(obj, dict):
for v in obj.values():
estimate(v)
elif isinstance(obj, (list, tuple)):
for item in obj:
estimate(item)
for output in value.outputs:
estimate(output)
return total

View File

@ -1,3 +1,4 @@
import asyncio
import bisect import bisect
import gc import gc
import itertools import itertools
@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class, enable_providers=False):
self.key_class = key_class self.key_class = key_class
self.initialized = False self.initialized = False
self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet self.cache_key_set: CacheKeySet
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt self.dynprompt = dynprompt
@ -196,18 +199,138 @@ class BasicCache:
def poll(self, **kwargs): def poll(self, **kwargs):
pass pass
def _set_immediate(self, node_id, value): def get_local(self, node_id):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized: if not self.initialized:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache: if cache_key in self.cache:
return self.cache[cache_key] return self.cache[cache_key]
else: return None
def set_local(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
async def _set_immediate(self, node_id, value):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
await self._notify_providers_store(node_id, cache_key, value)
async def _get_immediate(self, node_id):
if not self.initialized:
return None
cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache:
return self.cache[cache_key]
external_result = await self._check_providers_lookup(node_id, cache_key)
if external_result is not None:
self.cache[cache_key] = external_result
return external_result
return None
async def _notify_providers_store(self, node_id, cache_key, value):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return
if not _has_cache_providers():
return
if not self._is_external_cacheable_value(value):
return
if _contains_self_unequal(cache_key):
return
context = self._build_context(node_id, cache_key)
if context is None:
return
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
for provider in _get_cache_providers():
try:
if provider.should_cache(context, cache_value):
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
self._pending_store_tasks.add(task)
task.add_done_callback(self._pending_store_tasks.discard)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
from comfy_execution.cache_provider import _logger
try:
await provider.on_store(context, cache_value)
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
async def _check_providers_lookup(self, node_id, cache_key):
from comfy_execution.cache_provider import (
_has_cache_providers, _get_cache_providers,
CacheValue, _contains_self_unequal, _logger
)
if not self.enable_providers:
return None
if not _has_cache_providers():
return None
if _contains_self_unequal(cache_key):
return None
context = self._build_context(node_id, cache_key)
if context is None:
return None
for provider in _get_cache_providers():
try:
if not provider.should_cache(context):
continue
result = await provider.on_lookup(context)
if result is not None:
if not isinstance(result, CacheValue):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
continue
if not isinstance(result.outputs, (list, tuple)):
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
continue
from execution import CacheEntry
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
except Exception as e:
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
return None
def _is_external_cacheable_value(self, value):
return hasattr(value, 'outputs') and hasattr(value, 'ui')
def _get_class_type(self, node_id):
if not self.initialized or not self.dynprompt:
return ''
try:
return self.dynprompt.get_node(node_id).get('class_type', '')
except Exception:
return ''
def _build_context(self, node_id, cache_key):
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try:
cache_key_hash = _serialize_cache_key(cache_key)
if cache_key_hash is None:
return None
return CacheContext(
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash,
)
except Exception as e:
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None return None
async def _ensure_subcache(self, node_id, children_ids): async def _ensure_subcache(self, node_id, children_ids):
@ -236,8 +359,8 @@ class BasicCache:
return result return result
class HierarchicalCache(BasicCache): class HierarchicalCache(BasicCache):
def __init__(self, key_class): def __init__(self, key_class, enable_providers=False):
super().__init__(key_class) super().__init__(key_class, enable_providers=enable_providers)
def _get_cache_for(self, node_id): def _get_cache_for(self, node_id):
assert self.dynprompt is not None assert self.dynprompt is not None
@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
return None return None
return cache return cache
def get(self, node_id): async def get(self, node_id):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
if cache is None: if cache is None:
return None return None
return cache._get_immediate(node_id) return await cache._get_immediate(node_id)
def set(self, node_id, value): def get_local(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return BasicCache.get_local(cache, node_id)
async def set(self, node_id, value):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
assert cache is not None assert cache is not None
cache._set_immediate(node_id, value) await cache._set_immediate(node_id, value)
def set_local(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
@ -287,18 +421,24 @@ class NullCache:
def poll(self, **kwargs): def poll(self, **kwargs):
pass pass
def get(self, node_id): async def get(self, node_id):
return None return None
def set(self, node_id, value): def get_local(self, node_id):
return None
async def set(self, node_id, value):
pass
def set_local(self, node_id, value):
pass pass
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
return self return self
class LRUCache(BasicCache): class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100): def __init__(self, key_class, max_size=100, enable_providers=False):
super().__init__(key_class) super().__init__(key_class, enable_providers=enable_providers)
self.max_size = max_size self.max_size = max_size
self.min_generation = 0 self.min_generation = 0
self.generation = 0 self.generation = 0
@ -322,18 +462,18 @@ class LRUCache(BasicCache):
del self.children[key] del self.children[key]
self._clean_subcaches() self._clean_subcaches()
def get(self, node_id): async def get(self, node_id):
self._mark_used(node_id) self._mark_used(node_id)
return self._get_immediate(node_id) return await self._get_immediate(node_id)
def _mark_used(self, node_id): def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None: if cache_key is not None:
self.used_generation[cache_key] = self.generation self.used_generation[cache_key] = self.generation
def set(self, node_id, value): async def set(self, node_id, value):
self._mark_used(node_id) self._mark_used(node_id)
return self._set_immediate(node_id, value) return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes # Just uses subcaches for tracking 'live' nodes
@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache): class RAMPressureCache(LRUCache):
def __init__(self, key_class): def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, 0) super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {} self.timestamps = {}
def clean_unused(self): def clean_unused(self):
self._clean_subcaches() self._clean_subcaches()
def set(self, node_id, value): async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value) await super().set(node_id, value)
def get(self, node_id): async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id) return await super().get(node_id)
def poll(self, ram_headroom): def poll(self, ram_headroom):
def _ram_gb(): def _ram_gb():

View File

@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {} self.execution_cache_listeners = {}
def is_cached(self, node_id): def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None return self.output_cache.get_local(node_id) is not None
def cache_link(self, from_node_id, to_node_id): def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache: if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {} self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
if from_node_id not in self.execution_cache_listeners: if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id) self.execution_cache_listeners[from_node_id].add(to_node_id)
@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None: if value is None:
return None return None
#Write back to the main cache on touch. #Write back to the main cache on touch.
self.output_cache.set(from_node_id, value) self.output_cache.set_local(from_node_id, value)
return value return value
def cache_update(self, node_id, value): def cache_update(self, node_id, value):

View File

@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
node_id="EmptyLatentAudio", node_id="EmptyLatentAudio",
display_name="Empty Latent Audio", display_name="Empty Latent Audio",
category="latent/audio", category="latent/audio",
essentials_category="Audio",
inputs=[ inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input( IO.Int.Input(
@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
search_aliases=["export mp3"], search_aliases=["export mp3"],
display_name="Save Audio (MP3)", display_name="Save Audio (MP3)",
category="audio", category="audio",
essentials_category="Audio",
inputs=[ inputs=[
IO.Audio.Input("audio"), IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"), IO.String.Input("filename_prefix", default="audio/ComfyUI"),

View File

@ -6,6 +6,7 @@ import comfy.model_management
import torch import torch
import math import math
import nodes import nodes
import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode): class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod @classmethod
@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len)) sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas) return io.NodeOutput(sigmas)
class KV_Attn_Input:
def __init__(self):
self.cache = {}
def __call__(self, q, k, v, extra_options, **kwargs):
reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
if len(reference_image_num_tokens) == 0:
return {}
ref_toks = sum(reference_image_num_tokens)
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
if cache_key in self.cache:
kk, vv = self.cache[cache_key]
self.set_cache = False
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
self.set_cache = True
return {"q": q, "k": k, "v": v}
def cleanup(self):
self.cache = {}
class FluxKVCache(io.ComfyNode):
@classmethod
def define_schema(cls) -> io.Schema:
return io.Schema(
node_id="FluxKVCache",
display_name="Flux KV Cache",
description="Enables KV Cache optimization for reference images on Flux family models.",
category="",
is_experimental=True,
inputs=[
io.Model.Input("model", tooltip="The model to use KV Cache on."),
],
outputs=[
io.Model.Output(tooltip="The patched model with KV Cache enabled."),
],
)
@classmethod
def execute(cls, model: io.Model.Type) -> io.NodeOutput:
m = model.clone()
input_patch_obj = KV_Attn_Input()
def model_input_patch(inputs):
if len(input_patch_obj.cache) > 0:
ref_image_tokens = sum(inputs["transformer_options"].get("reference_image_num_tokens", []))
if ref_image_tokens > 0:
img = inputs["img"]
inputs["img"] = img[:, :-ref_image_tokens]
return inputs
m.set_model_attn1_patch(input_patch_obj)
m.set_model_post_input_patch(model_input_patch)
if hasattr(model.model.diffusion_model, "params"):
m.add_object_patch("diffusion_model.params.default_ref_method", "index_timestep_zero")
else:
m.add_object_patch("diffusion_model.default_ref_method", "index_timestep_zero")
return io.NodeOutput(m)
class FluxExtension(ComfyExtension): class FluxExtension(ComfyExtension):
@override @override
@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod, FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage, EmptyFlux2LatentImage,
Flux2Scheduler, Flux2Scheduler,
FluxKVCache,
] ]

View File

@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
display_name="Image Compare", display_name="Image Compare",
description="Compares two images side by side with a slider.", description="Compares two images side by side with a slider.",
category="image", category="image",
essentials_category="Image Tools",
is_experimental=True, is_experimental=True,
is_output_node=True, is_output_node=True,
inputs=[ inputs=[

View File

@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
search_aliases=["trim"], search_aliases=["trim"],
display_name="Image Crop", display_name="Image Crop",
category="image/transform", category="image/transform",
essentials_category="Image Tools",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"), IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@ -0,0 +1,127 @@
from __future__ import annotations
import hashlib
import os
import numpy as np
import torch
from PIL import Image
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
hex_color = hex_color.lstrip("#")
if len(hex_color) != 6:
return (0.0, 0.0, 0.0)
r = int(hex_color[0:2], 16) / 255.0
g = int(hex_color[2:4], 16) / 255.0
b = int(hex_color[4:6], 16) / 255.0
return (r, g, b)
class PainterNode(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="Painter",
display_name="Painter",
category="image",
inputs=[
io.Image.Input(
"image",
optional=True,
tooltip="Optional base image to paint over",
),
io.String.Input(
"mask",
default="",
socketless=True,
extra_dict={"widgetType": "PAINTER", "image_upload": True},
),
io.Int.Input(
"width",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Int.Input(
"height",
default=512,
min=64,
max=4096,
step=64,
socketless=True,
extra_dict={"hidden": True},
),
io.Color.Input("bg_color", default="#000000"),
],
outputs=[
io.Image.Output("IMAGE"),
io.Mask.Output("MASK"),
],
)
@classmethod
def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
if image is not None:
base_image = image[:1]
h, w = base_image.shape[1], base_image.shape[2]
else:
h, w = height, width
r, g, b = hex_to_rgb(bg_color)
base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
base_image[0, :, :, 0] = r
base_image[0, :, :, 1] = g
base_image[0, :, :, 2] = b
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
painter_img = node_helpers.pillow(Image.open, mask_path)
painter_img = painter_img.convert("RGBA")
if painter_img.size != (w, h):
painter_img = painter_img.resize((w, h), Image.LANCZOS)
painter_np = np.array(painter_img).astype(np.float32) / 255.0
painter_rgb = painter_np[:, :, :3]
painter_alpha = painter_np[:, :, 3:4]
mask_tensor = torch.from_numpy(painter_np[:, :, 3]).unsqueeze(0)
base_np = base_image[0].cpu().numpy()
composited = painter_rgb * painter_alpha + base_np * (1.0 - painter_alpha)
out_image = torch.from_numpy(composited).unsqueeze(0)
else:
mask_tensor = torch.zeros((1, h, w), dtype=torch.float32)
out_image = base_image
return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))
@classmethod
def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
if mask and mask.strip():
mask_path = folder_paths.get_annotated_filepath(mask)
if os.path.exists(mask_path):
m = hashlib.sha256()
with open(mask_path, "rb") as f:
m.update(f.read())
return m.digest().hex()
return ""
class PainterExtension(ComfyExtension):
@override
async def get_node_list(self):
return [PainterNode]
async def comfy_entrypoint():
return PainterExtension()

View File

@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
node_id="ImageBlend", node_id="ImageBlend",
display_name="Image Blend", display_name="Image Blend",
category="image/postprocessing", category="image/postprocessing",
essentials_category="Image Tools",
inputs=[ inputs=[
io.Image.Input("image1"), io.Image.Input("image1"),
io.Image.Input("image2"), io.Image.Input("image2"),

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler import comfy_extras.nodes_custom_sampler
import folder_paths import folder_paths
import node_helpers import node_helpers
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16, training_dtype=torch.bfloat16,
real_dataset=None, real_dataset=None,
bucket_latents=None, bucket_latents=None,
use_grad_scaler=False,
): ):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = ( self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi) bucket_latents # list of (Bi, C, Hi, Wi)
) )
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling # Precompute bucket offsets and weights for sampling
if bucket_latents is not None: if bucket_latents is not None:
self._init_bucket_data(bucket_latents) self._init_bucket_data(bucket_latents)
@ -204,9 +208,12 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True), batch_sigmas.requires_grad_(True),
**batch_extra_args, **batch_extra_args,
) )
loss = self.loss_fn(x0_pred, x0) loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd: if bwd:
bwd_loss = loss / self.grad_acc bwd_loss = loss / self.grad_acc
if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward() bwd_loss.backward()
return loss return loss
@ -307,6 +314,9 @@ class TrainSampler(comfy.samplers.Sampler):
) )
total_loss += loss total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies) total_loss = total_loss / self.grad_acc / len(indicies)
if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward() total_loss.backward()
if self.loss_callback: if self.loss_callback:
self.loss_callback(total_loss.item()) self.loss_callback(total_loss.item())
@ -348,11 +358,17 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0: if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups: for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]: for param in param_groups["params"]:
if param.grad is None: if param.grad is None:
continue continue
param.grad.data = param.grad.data.to(param.data.dtype) param.grad.data = param.grad.data.to(param.data.dtype)
if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step() self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
ui_pbar.update(1) ui_pbar.update(1)
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
), ),
io.Combo.Input( io.Combo.Input(
"training_dtype", "training_dtype",
options=["bf16", "fp32"], options=["bf16", "fp32", "none"],
default="bf16", default="bf16",
tooltip="The dtype to use for training.", tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
), ),
io.Combo.Input( io.Combo.Input(
"lora_dtype", "lora_dtype",
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input( io.Boolean.Input(
"offloading", "offloading",
default=False, default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.", tooltip="Offload model weights to CPU during training to save GPU memory.",
), ),
io.Combo.Input( io.Combo.Input(
"existing_lora", "existing_lora",
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype # Setup model and dtype
mp = model.clone() mp = model.clone()
use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype) dtype = node_helpers.string_to_torch_dtype(training_dtype)
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype) mp.set_model_compute_dtype(dtype)
else:
if mp.is_dynamic(): # Detect model's native dtype for autocast
if not bypass_mode: model_dtype = mp.model.get_dtype()
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode") if model_dtype == torch.float16:
bypass_mode = True dtype = torch.float16
offloading = True use_grad_scaler = True
elif offloading: # Warn about fp16 accumulation instability during training
if not bypass_mode: if PerformanceFeature.Fp16Accumulation in args.fast:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message") logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
# Prepare latents and compute counts # Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count( latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode latents, latents_dtype, bucket_mode
) )
# Validate and expand conditioning # Validate and expand conditioning
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed, seed=seed,
training_dtype=dtype, training_dtype=dtype,
bucket_latents=latents, bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
) )
else: else:
train_sampler = TrainSampler( train_sampler = TrainSampler(
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed, seed=seed,
training_dtype=dtype, training_dtype=dtype,
real_dataset=latents if multi_res else None, real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
) )
# Setup guider # Setup guider
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input( io.Int.Input(
"steps", "steps",
optional=True, optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
), ),
], ],
outputs=[], outputs=[],

View File

@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar) s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False oom = False
except model_management.OOM_EXCEPTION as e: except Exception as e:
model_management.raise_non_oom(e)
tile //= 2 tile //= 2
if tile < 128: if tile < 128:
raise e raise e

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.16.3" __version__ = "0.17.0"

View File

@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -126,15 +127,15 @@ class CacheSet:
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size): def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom): def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs'] inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type'] class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = caches.outputs.get(unique_id) cached = await caches.outputs.get(unique_id)
if cached is not None: if cached is not None:
if server.client_id is not None: if server.client_id is not None:
cached_ui = cached.ui or {} cached_ui = cached.ui or {}
@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id) obj = await caches.objects.get(unique_id)
if obj is None: if obj is None:
obj = class_def() obj = class_def()
caches.objects.set(unique_id, obj) await caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal): if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry) execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry) await caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex: except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
@ -612,7 +613,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.error(traceback.format_exc()) logging.error(traceback.format_exc())
tips = "" tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION): if comfy.model_management.is_oom(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number." tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary())) logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.") logging.error("Got an OOM, unloading all loaded models.")
@ -684,6 +685,19 @@ class PromptExecutor:
} }
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
if not _has_cache_providers():
return
for provider in _get_cache_providers():
try:
if event == "start":
provider.on_prompt_start(prompt_id)
elif event == "end":
provider.on_prompt_end(prompt_id)
except Exception as e:
_cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@ -700,6 +714,9 @@ class PromptExecutor:
self.status_messages = [] self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
self._notify_prompt_lifecycle("start", prompt_id)
try:
with torch.inference_mode(): with torch.inference_mode():
dynamic_prompt = DynamicPrompt(prompt) dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt) reset_progress_state(prompt_id, dynamic_prompt)
@ -709,10 +726,14 @@ class PromptExecutor:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache) await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused() cache.clean_unused()
cached_nodes = [] node_ids = list(prompt.keys())
for node_id in prompt: cache_results = await asyncio.gather(
if self.caches.outputs.get(node_id) is not None: *(self.caches.outputs.get(node_id) for node_id in node_ids)
cached_nodes.append(node_id) )
cached_nodes = [
node_id for node_id, result in zip(node_ids, cache_results)
if result is not None
]
comfy.model_management.cleanup_models_gc() comfy.model_management.cleanup_models_gc()
self.add_message("execution_cached", self.add_message("execution_cached",
@ -760,6 +781,8 @@ class PromptExecutor:
self.server.last_node_id = None self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY: if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models() comfy.model_management.unload_all_models()
finally:
self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated):

59
main.py
View File

@ -3,18 +3,22 @@ comfy.options.enable_args_parsing()
import os import os
import importlib.util import importlib.util
import shutil
import importlib.metadata
import folder_paths import folder_paths
import time import time
from comfy.cli_args import args, enables_dynamic_vram from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools import itertools
import utils.extra_config import utils.extra_config
from utils.mime_types import init_mime_types
import faulthandler
import logging import logging
import sys import sys
from comfy_execution.progress import get_progress_state from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags from comfy_api import feature_flags
from app.database.db import init_db, dependencies_available
if __name__ == "__main__": if __name__ == "__main__":
#NOTE: These do not do anything on core ComfyUI, they are for custom nodes. #NOTE: These do not do anything on core ComfyUI, they are for custom nodes.
@ -23,6 +27,8 @@ if __name__ == "__main__":
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout) setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
faulthandler.enable(file=sys.stderr, all_threads=False)
import comfy_aimdo.control import comfy_aimdo.control
if enables_dynamic_vram(): if enables_dynamic_vram():
@ -62,8 +68,15 @@ if __name__ == "__main__":
def handle_comfyui_manager_unavailable(): def handle_comfyui_manager_unavailable():
if not args.windows_standalone_build: manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n") uv_available = shutil.which("uv") is not None
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
if uv_available:
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
msg += "\n"
logging.warning(msg)
args.enable_manager = False args.enable_manager = False
@ -161,6 +174,7 @@ def execute_prestartup_script():
logging.info("") logging.info("")
apply_custom_paths() apply_custom_paths()
init_mime_types()
if args.enable_manager: if args.enable_manager:
comfyui_manager.prestartup() comfyui_manager.prestartup()
@ -170,7 +184,6 @@ execute_prestartup_script()
# Main code # Main code
import asyncio import asyncio
import shutil
import threading import threading
import gc import gc
@ -179,6 +192,7 @@ if 'torch' in sys.modules:
import comfy.utils import comfy.utils
from app.assets.seeder import asset_seeder
import execution import execution
import server import server
@ -192,8 +206,8 @@ import hook_breaker_ac10a0
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if comfy.model_management.torch_version_numeric < (2, 8): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG': if args.verbose == 'DEBUG':
@ -258,6 +272,7 @@ def prompt_worker(q, server_instance):
for k in sensitive: for k in sensitive:
extra_data[k] = sensitive[k] extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4]) e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True need_gc = True
@ -302,6 +317,7 @@ def prompt_worker(q, server_instance):
last_gc_collect = current_time last_gc_collect = current_time
need_gc = False need_gc = False
hook_breaker_ac10a0.restore_functions() hook_breaker_ac10a0.restore_functions()
asset_seeder.resume()
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None): async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
@ -352,12 +368,29 @@ def cleanup_temp():
def setup_database(): def setup_database():
try: try:
from app.database.db import init_db, dependencies_available
if dependencies_available(): if dependencies_available():
init_db() init_db()
if not args.disable_assets_autoscan: if args.enable_assets:
seed_assets(["models"], enable_logging=True) if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
logging.info("Background asset scan initiated for models, input, output")
except Exception as e: except Exception as e:
if "database is locked" in str(e):
logging.error(
"Database is locked. Another ComfyUI process is already using this database.\n"
"To resolve this, specify a separate database file for this instance:\n"
" --database-url sqlite:///path/to/another.db"
)
sys.exit(1)
if args.enable_assets:
logging.error(
f"Failed to initialize database: {e}\n"
"The --enable-assets flag requires a working database connection.\n"
"To resolve this, try one of the following:\n"
" 1. Install the latest requirements: pip install -r requirements.txt\n"
" 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n"
" 3. Use an in-memory database: --database-url sqlite:///:memory:"
)
sys.exit(1)
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}") logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
@ -429,6 +462,11 @@ if __name__ == "__main__":
# Running directly, just start ComfyUI. # Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version)) logging.info("Python version: {}".format(sys.version))
logging.info("ComfyUI version: {}".format(comfyui_version.__version__)) logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
for package in ("comfy-aimdo", "comfy-kitchen"):
try:
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
except:
pass
if sys.version_info.major == 3 and sys.version_info.minor < 10: if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
@ -440,5 +478,6 @@ if __name__ == "__main__":
event_loop.run_until_complete(x) event_loop.run_until_complete(x)
except KeyboardInterrupt: except KeyboardInterrupt:
logging.info("\nStopped server") logging.info("\nStopped server")
finally:
asset_seeder.shutdown()
cleanup_temp() cleanup_temp()

View File

@ -1 +1 @@
comfyui_manager==4.1b1 comfyui_manager==4.1b5

View File

@ -32,7 +32,7 @@ async def cache_control(
) )
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-cache") response.headers.setdefault("Cache-Control", "no-store")
return response return response
# Early return for non-image files - no cache headers needed # Early return for non-image files - no cache headers needed

View File

@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
class ConditioningCombine: class ConditioningCombine:
ESSENTIALS_CATEGORY = "Image Generation"
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}} return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@ -1211,9 +1212,6 @@ class GLIGENTextBoxApply:
return (c, ) return (c, )
class EmptyLatentImage: class EmptyLatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
@ -1232,7 +1230,7 @@ class EmptyLatentImage:
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return ({"samples": latent, "downscale_ratio_spacial": 8}, ) return ({"samples": latent, "downscale_ratio_spacial": 8}, )
@ -1724,6 +1722,8 @@ class LoadImage:
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
@ -1748,8 +1748,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO": if img.format == "MPO":
break # ignore all frames except the first one for MPO format break # ignore all frames except the first one for MPO format
@ -1779,6 +1779,7 @@ class LoadImage:
return True return True
class LoadImageMask: class LoadImageMask:
ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@ -1887,6 +1888,7 @@ class ImageScale:
return (s,) return (s,)
class ImageScaleBy: class ImageScaleBy:
ESSENTIALS_CATEGORY = "Image Tools"
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod @classmethod
@ -2450,6 +2452,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py", "nodes_nag.py",
"nodes_sdpose.py", "nodes_sdpose.py",
"nodes_math.py", "nodes_math.py",
"nodes_painter.py",
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.16.3" version = "0.17.0"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19 comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.10 comfyui-workflow-templates==0.9.21
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
@ -20,11 +20,13 @@ tqdm
psutil psutil
alembic alembic
SQLAlchemy SQLAlchemy
filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.7 comfy-aimdo>=0.2.12
requests requests
simpleeval>=1.0 simpleeval>=1.0.0
blake3
#non essential dependencies: #non essential dependencies:
kornia>=0.7.1 kornia>=0.7.1

View File

@ -33,8 +33,10 @@ import node_helpers
from comfyui_version import __version__ from comfyui_version import __version__
from app.frontend_management import FrontendManager, parse_version from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.scanner import seed_assets from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_system from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -197,10 +199,6 @@ class PromptServer():
def __init__(self, loop): def __init__(self, loop):
PromptServer.instance = self PromptServer.instance = self
mimetypes.init()
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
mimetypes.add_type('image/webp', '.webp')
self.user_manager = UserManager() self.user_manager = UserManager()
self.model_file_manager = ModelFileManager() self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager() self.custom_node_manager = CustomNodeManager()
@ -239,7 +237,11 @@ class PromptServer():
else args.front_end_root else args.front_end_root
) )
logging.info(f"[Prompt Server] web root: {self.web_root}") logging.info(f"[Prompt Server] web root: {self.web_root}")
register_assets_system(self.app, self.user_manager) if args.enable_assets:
register_assets_routes(self.app, self.user_manager)
else:
register_assets_routes(self.app)
asset_seeder.disable()
routes = web.RouteTableDef() routes = web.RouteTableDef()
self.routes = routes self.routes = routes
self.last_node_id = None self.last_node_id = None
@ -310,7 +312,7 @@ class PromptServer():
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html")) response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache' response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers["Pragma"] = "no-cache" response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0" response.headers["Expires"] = "0"
return response return response
@ -419,7 +421,24 @@ class PromptServer():
with open(filepath, "wb") as f: with open(filepath, "wb") as f:
f.write(image.file.read()) f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,
"asset_hash": result.asset.hash,
"size": result.asset.size_bytes,
"mime_type": result.asset.mime_type,
"tags": result.tags,
}
except Exception:
logging.warning("Failed to register uploaded image as asset", exc_info=True)
return web.json_response(resp)
else: else:
return web.Response(status=400) return web.Response(status=400)
@ -479,6 +498,19 @@ class PromptServer():
async def view_image(request): async def view_image(request):
if "filename" in request.rel_url.query: if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"] filename = request.rel_url.query["filename"]
# The frontend's LoadImage combo widget uses asset_hash values
# (e.g. "blake3:...") as widget values. When litegraph renders the
# node preview, it constructs /view?filename=<asset_hash>, so this
# endpoint must resolve blake3 hashes to their on-disk file paths.
if filename.startswith("blake3:"):
owner_id = self.user_manager.get_request_user_id(request)
result = resolve_hash_to_path(filename, owner_id=owner_id)
if result is None:
return web.Response(status=404)
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
else:
resolved_content_type = None
filename, output_dir = folder_paths.annotated_filepath(filename) filename, output_dir = folder_paths.annotated_filepath(filename)
if not filename: if not filename:
@ -562,8 +594,13 @@ class PromptServer():
return web.Response(body=alpha_buffer.read(), content_type='image/png', return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"filename=\"{filename}\""})
else: else:
# Get content type from mimetype, defaulting to 'application/octet-stream' # Use the content type from asset resolution if available,
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' # otherwise guess from the filename.
content_type = (
resolved_content_type
or mimetypes.guess_type(filename)[0]
or 'application/octet-stream'
)
# For security, force certain mimetypes to download instead of display # For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
@ -697,10 +734,7 @@ class PromptServer():
@routes.get("/object_info") @routes.get("/object_info")
async def get_object_info(request): async def get_object_info(request):
try: asset_seeder.start(roots=("models", "input", "output"))
seed_assets(["models"])
except Exception as e:
logging.error(f"Failed to seed assets: {e}")
with folder_paths.cache_helper: with folder_paths.cache_helper:
out = {} out = {}
for x in nodes.NODE_CLASS_MAPPINGS: for x in nodes.NODE_CLASS_MAPPINGS:

View File

@ -0,0 +1,57 @@
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
This catches problems like unnamed FK constraints that prevent batch-mode
drop_constraint from working on real SQLite files (see MB-2).
Migrations 0001 and 0002 are already shipped, so we only exercise
upgrade/downgrade for 0003+.
"""
import os
import pytest
from alembic import command
from alembic.config import Config
# Oldest shipped revision — we upgrade to here as a baseline and never
# downgrade past it.
_BASELINE = "0002_merge_to_asset_references"
def _make_config(db_path: str) -> Config:
root = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
return cfg
@pytest.fixture
def migration_db(tmp_path):
"""Yield an alembic Config pre-upgraded to the baseline revision."""
db_path = str(tmp_path / "test_migration.db")
cfg = _make_config(db_path)
command.upgrade(cfg, _BASELINE)
yield cfg
def test_upgrade_to_head(migration_db):
"""Upgrade from baseline to head must succeed on a file-backed DB."""
command.upgrade(migration_db, "head")
def test_downgrade_to_baseline(migration_db):
"""Upgrade to head then downgrade back to baseline."""
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
def test_upgrade_downgrade_cycle(migration_db):
"""Full cycle: upgrade → downgrade → upgrade again."""
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
command.upgrade(migration_db, "head")

View File

@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest)
"main.py", "main.py",
f"--base-directory={str(comfy_tmp_base_dir)}", f"--base-directory={str(comfy_tmp_base_dir)}",
f"--database-url={db_url}", f"--database-url={db_url}",
"--disable-assets-autoscan", "--enable-assets",
"--listen", "--listen",
"127.0.0.1", "127.0.0.1",
"--port", "--port",
@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str):
for aid in created: for aid in created:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30) http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
@pytest.fixture @pytest.fixture
@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
break break
for aid in ids: for aid in ids:
with contextlib.suppress(Exception): with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30) http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@ -0,0 +1,28 @@
"""Helper functions for assets integration tests."""
import time
import requests
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a synchronous sync/seed pass by calling the seed endpoint with wait=true.
Retries on 409 (already running) until the previous scan finishes.
"""
deadline = time.monotonic() + 60
while True:
r = session.post(
base_url + "/api/assets/seed?wait=true",
json={"roots": ["models", "input", "output"]},
timeout=60,
)
if r.status_code != 409:
assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}"
return
if time.monotonic() > deadline:
raise TimeoutError("seed endpoint stuck in 409 (already running)")
time.sleep(0.25)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@ -0,0 +1,20 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
"""In-memory SQLite session for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as sess:
yield sess
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
"""Override parent autouse fixture - query tests don't need server cleanup."""
yield

View File

@ -0,0 +1,187 @@
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.helpers import get_utc_now
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
update_asset_hash_and_mime,
)
class TestAssetExistsByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,expected",
[
(None, "nonexistent", False), # No asset exists
("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash
(None, "", False), # Null hash in DB doesn't match empty string
],
ids=["nonexistent", "existing", "null_hash_no_match"],
)
def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
if setup_hash is not None or query_hash == "":
asset = Asset(hash=setup_hash, size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash=query_hash) is expected
class TestGetAssetByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,should_find",
[
(None, "nonexistent", False),
("blake3:def456", "blake3:def456", True),
],
ids=["nonexistent", "existing"],
)
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
if setup_hash is not None:
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
session.add(asset)
session.commit()
result = get_asset_by_hash(session, asset_hash=query_hash)
if should_find:
assert result is not None
assert result.size_bytes == 200
assert result.mime_type == "image/png"
else:
assert result is None
class TestUpsertAsset:
@pytest.mark.parametrize(
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
[
# New asset creation
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
# Existing asset, same values - no update
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
# Existing asset with size 0, update with new values
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
# Existing asset, second call with size 0 - no update
(1000, None, 0, None, False, False, 1000, None),
],
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
)
def test_upsert_scenarios(
self,
session: Session,
first_size,
first_mime,
second_size,
second_mime,
expect_created,
expect_updated,
final_size,
final_mime,
):
asset_hash = f"blake3:test_{first_size}_{second_size}"
# First upsert (if first_size is not None, we're testing the second call)
if first_size is not None:
upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=first_size,
mime_type=first_mime,
)
session.commit()
# The upsert call we're testing
asset, created, updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=second_size,
mime_type=second_mime,
)
session.commit()
assert created is expect_created
assert updated is expect_updated
assert asset.size_bytes == final_size
assert asset.mime_type == final_mime
class TestBulkInsertAssets:
def test_inserts_multiple_assets(self, session: Session):
now = get_utc_now()
rows = [
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now},
]
bulk_insert_assets(session, rows)
session.commit()
assets = session.query(Asset).all()
assert len(assets) == 3
hashes = {a.hash for a in assets}
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
def test_empty_list_is_noop(self, session: Session):
bulk_insert_assets(session, [])
session.commit()
assert session.query(Asset).count() == 0
def test_handles_large_batch(self, session: Session):
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
now = get_utc_now()
rows = [
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now}
for i in range(200)
]
bulk_insert_assets(session, rows)
session.commit()
assert session.query(Asset).count() == 200
class TestMimeTypeImmutability:
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
@pytest.mark.parametrize(
"initial_mime,second_mime,expected_mime",
[
("image/png", "image/jpeg", "image/png"),
(None, "image/png", "image/png"),
],
ids=["preserves_existing", "fills_null"],
)
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
h = f"blake3:upsert_{initial_mime}_{second_mime}"
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
session.commit()
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
assert created is False
assert asset.mime_type == expected_mime
@pytest.mark.parametrize(
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
[
(None, "image/png", None, "image/png", "blake3:upd0"),
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
],
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
)
def test_update_asset_hash_and_mime_immutability(
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
):
h = expected_hash.removesuffix("_new")
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
session.add(asset)
session.flush()
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
assert asset.mime_type == expected_mime
assert asset.hash == expected_hash

View File

@ -0,0 +1,520 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
from app.assets.database.queries import (
reference_exists_for_asset_id,
get_reference_by_id,
insert_reference,
get_or_create_reference,
update_reference_timestamps,
list_references_page,
fetch_reference_asset_and_tags,
fetch_reference_and_asset,
update_reference_access_time,
set_reference_metadata,
delete_reference_by_id,
set_reference_preview,
bulk_insert_references_ignore_conflicts,
get_reference_ids_by_ids,
ensure_tags_exist,
add_tags_to_reference,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestReferenceExistsForAssetId:
def test_returns_false_when_no_reference(self, session: Session):
asset = _make_asset(session, "hash1")
assert reference_exists_for_asset_id(session, asset_id=asset.id) is False
def test_returns_true_when_reference_exists(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset)
assert reference_exists_for_asset_id(session, asset_id=asset.id) is True
class TestGetReferenceById:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_reference_by_id(session, reference_id="nonexistent") is None
def test_returns_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="myfile.txt")
result = get_reference_by_id(session, reference_id=ref.id)
assert result is not None
assert result.name == "myfile.txt"
class TestListReferencesPage:
def test_empty_db(self, session: Session):
refs, tag_map, total = list_references_page(session)
assert refs == []
assert tag_map == {}
assert total == 0
def test_returns_references_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"])
session.commit()
refs, tag_map, total = list_references_page(session)
assert len(refs) == 1
assert refs[0].id == ref.id
assert set(tag_map[ref.id]) == {"alpha", "beta"}
assert total == 1
def test_name_contains_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="model_v1.safetensors")
_make_reference(session, asset, name="config.json")
session.commit()
refs, _, total = list_references_page(session, name_contains="model")
assert total == 1
assert refs[0].name == "model_v1.safetensors"
def test_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="public", owner_id="")
_make_reference(session, asset, name="private", owner_id="user1")
session.commit()
# Empty owner sees only public
refs, _, total = list_references_page(session, owner_id="")
assert total == 1
assert refs[0].name == "public"
# Owner sees both
refs, _, total = list_references_page(session, owner_id="user1")
assert total == 2
def test_include_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="tagged")
_make_reference(session, asset, name="untagged")
ensure_tags_exist(session, ["wanted"])
add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"])
session.commit()
refs, _, total = list_references_page(session, include_tags=["wanted"])
assert total == 1
assert refs[0].name == "tagged"
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="keep")
ref_exclude = _make_reference(session, asset, name="exclude")
ensure_tags_exist(session, ["bad"])
add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"])
session.commit()
refs, _, total = list_references_page(session, exclude_tags=["bad"])
assert total == 1
assert refs[0].name == "keep"
def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500)
_make_reference(session, asset, name="small")
_make_reference(session, asset2, name="large")
session.commit()
refs, _, _ = list_references_page(session, sort="size", order="desc")
assert refs[0].name == "large"
refs, _, _ = list_references_page(session, sort="name", order="asc")
assert refs[0].name == "large"
class TestFetchReferenceAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_reference_asset_and_tags(session, "nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="test.bin")
ensure_tags_exist(session, ["tag1"])
add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"])
session.commit()
result = fetch_reference_asset_and_tags(session, ref.id)
assert result is not None
ret_ref, ret_asset, ret_tags = result
assert ret_ref.id == ref.id
assert ret_asset.id == asset.id
assert ret_tags == ["tag1"]
class TestFetchReferenceAndAsset:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_reference_and_asset(session, reference_id="nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
result = fetch_reference_and_asset(session, reference_id=ref.id)
assert result is not None
ret_ref, ret_asset = result
assert ret_ref.id == ref.id
assert ret_asset.id == asset.id
class TestUpdateReferenceAccessTime:
def test_updates_last_access_time(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
original_time = ref.last_access_time
session.commit()
import time
time.sleep(0.01)
update_reference_access_time(session, reference_id=ref.id)
session.commit()
session.refresh(ref)
assert ref.last_access_time > original_time
class TestDeleteReferenceById:
def test_deletes_existing(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
result = delete_reference_by_id(session, reference_id=ref.id, owner_id="")
assert result is True
assert get_reference_by_id(session, reference_id=ref.id) is None
def test_returns_false_for_nonexistent(self, session: Session):
result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="")
assert result is False
def test_respects_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, owner_id="user1")
session.commit()
result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2")
assert result is False
assert get_reference_by_id(session, reference_id=ref.id) is not None
class TestSetReferencePreview:
def test_sets_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit()
set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
session.commit()
session.refresh(ref)
assert ref.preview_id == preview_ref.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
session.commit()
set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
session.commit()
session.refresh(ref)
assert ref.preview_id is None
def test_raises_for_nonexistent_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
with pytest.raises(ValueError, match="Preview AssetReference"):
set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
class TestInsertReference:
def test_creates_new_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref = insert_reference(
session, asset_id=asset.id, owner_id="user1", name="test.bin"
)
session.commit()
assert ref is not None
assert ref.name == "test.bin"
assert ref.owner_id == "user1"
def test_allows_duplicate_names(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
session.commit()
# Duplicate names are now allowed
ref2 = insert_reference(
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
)
session.commit()
assert ref1 is not None
assert ref2 is not None
assert ref1.id != ref2.id
class TestGetOrCreateReference:
def test_creates_new_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref, created = get_or_create_reference(
session, asset_id=asset.id, owner_id="user1", name="new.bin"
)
session.commit()
assert created is True
assert ref.name == "new.bin"
def test_always_creates_new_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref1, created1 = get_or_create_reference(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
# Duplicate names are allowed, so always creates new
ref2, created2 = get_or_create_reference(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
assert created1 is True
assert created2 is True
assert ref1.id != ref2.id
class TestUpdateReferenceTimestamps:
def test_updates_timestamps(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
original_updated_at = ref.updated_at
session.commit()
time.sleep(0.01)
update_reference_timestamps(session, ref)
session.commit()
session.refresh(ref)
assert ref.updated_at > original_updated_at
def test_updates_preview_id(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit()
update_reference_timestamps(session, ref, preview_id=preview_ref.id)
session.commit()
session.refresh(ref)
assert ref.preview_id == preview_ref.id
class TestSetReferenceMetadata:
def test_sets_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"key": "value"}
)
session.commit()
session.refresh(ref)
assert ref.user_metadata == {"key": "value"}
# Check metadata table
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 1
assert meta[0].key == "key"
assert meta[0].val_str == "value"
def test_replaces_existing_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"old": "data"}
)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"new": "data"}
)
session.commit()
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 1
assert meta[0].key == "new"
def test_clears_metadata_with_empty_dict(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"key": "value"}
)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={}
)
session.commit()
session.refresh(ref)
assert ref.user_metadata == {}
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 0
def test_raises_for_nonexistent(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_reference_metadata(
session, reference_id="nonexistent", user_metadata={"key": "value"}
)
class TestBulkInsertReferencesIgnoreConflicts:
def test_inserts_multiple_references(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk1.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk2.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
refs = session.query(AssetReference).all()
assert len(refs) == 2
def test_allows_duplicate_names(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="existing.bin", owner_id="")
session.commit()
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "existing.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "new.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
# Duplicate names allowed, so all 3 rows exist
refs = session.query(AssetReference).all()
assert len(refs) == 3
def test_empty_list_is_noop(self, session: Session):
bulk_insert_references_ignore_conflicts(session, [])
assert session.query(AssetReference).count() == 0
class TestGetReferenceIdsByIds:
def test_returns_existing_ids(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="a.bin")
ref2 = _make_reference(session, asset, name="b.bin")
session.commit()
found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"])
assert found == {ref1.id, ref2.id}
def test_empty_list_returns_empty(self, session: Session):
found = get_reference_ids_by_ids(session, [])
assert found == set()

View File

@ -0,0 +1,499 @@
"""Tests for cache_state (AssetReference file path) query functions."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import (
list_references_by_asset_id,
upsert_reference,
get_unreferenced_unhashed_asset_ids,
delete_assets_by_ids,
get_references_for_prefixes,
bulk_update_needs_verify,
delete_references_by_ids,
delete_orphaned_seed_asset,
bulk_insert_references_ignore_conflicts,
get_references_by_paths_and_asset_ids,
mark_references_missing_outside_prefixes,
restore_references_by_paths,
)
from app.assets.helpers import select_best_live_path, get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size)
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
file_path: str,
name: str = "test",
mtime_ns: int | None = None,
needs_verify: bool = False,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
asset_id=asset.id,
file_path=file_path,
name=name,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestListReferencesByAssetId:
def test_returns_empty_for_no_references(self, session: Session):
asset = _make_asset(session, "hash1")
refs = list_references_by_asset_id(session, asset_id=asset.id)
assert list(refs) == []
def test_returns_references_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/path/a.bin", name="a")
_make_reference(session, asset, "/path/b.bin", name="b")
session.commit()
refs = list_references_by_asset_id(session, asset_id=asset.id)
paths = [r.file_path for r in refs]
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
def test_does_not_return_other_assets_references(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_reference(session, asset1, "/path/asset1.bin", name="a1")
_make_reference(session, asset2, "/path/asset2.bin", name="a2")
session.commit()
refs = list_references_by_asset_id(session, asset_id=asset1.id)
paths = [r.file_path for r in refs]
assert paths == ["/path/asset1.bin"]
class TestSelectBestLivePath:
def test_returns_empty_for_empty_list(self):
result = select_best_live_path([])
assert result == ""
def test_returns_empty_when_no_files_exist(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, "/nonexistent/path.bin")
session.commit()
result = select_best_live_path([ref])
assert result == ""
def test_prefers_verified_path(self, session: Session, tmp_path):
"""needs_verify=False should be preferred."""
asset = _make_asset(session, "hash1")
verified_file = tmp_path / "verified.bin"
verified_file.write_bytes(b"data")
unverified_file = tmp_path / "unverified.bin"
unverified_file.write_bytes(b"data")
ref_verified = _make_reference(
session, asset, str(verified_file), name="verified", needs_verify=False
)
ref_unverified = _make_reference(
session, asset, str(unverified_file), name="unverified", needs_verify=True
)
session.commit()
refs = [ref_unverified, ref_verified]
result = select_best_live_path(refs)
assert result == str(verified_file)
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
"""If all references need verification, return first existing path."""
asset = _make_asset(session, "hash1")
existing_file = tmp_path / "exists.bin"
existing_file.write_bytes(b"data")
ref = _make_reference(session, asset, str(existing_file), needs_verify=True)
session.commit()
result = select_best_live_path([ref])
assert result == str(existing_file)
class TestSelectBestLivePathWithMocking:
def test_handles_missing_file_path_attr(self):
"""Gracefully handle references with None file_path."""
class MockRef:
file_path = None
needs_verify = False
result = select_best_live_path([MockRef()])
assert result == ""
class TestUpsertReference:
@pytest.mark.parametrize(
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
[
# New reference creation
(None, 12345, True, False, 12345),
# Existing reference, same mtime - no update
(100, 100, False, False, 100),
# Existing reference, different mtime - update
(100, 200, False, True, 200),
],
ids=["new_reference", "existing_no_change", "existing_update_mtime"],
)
def test_upsert_scenarios(
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
):
asset = _make_asset(session, "hash1")
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
name = f"file_{initial_mtime}_{second_mtime}"
# Create initial reference if needed
if initial_mtime is not None:
upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime)
session.commit()
# The upsert call we're testing
created, updated = upsert_reference(
session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime
)
session.commit()
assert created is expect_created
assert updated is expect_updated
ref = session.query(AssetReference).filter_by(file_path=file_path).one()
assert ref.mtime_ns == final_mtime
def test_upsert_restores_missing_reference(self, session: Session):
"""Upserting a reference that was marked missing should restore it."""
asset = _make_asset(session, "hash1")
file_path = "/restored/file.bin"
ref = _make_reference(session, asset, file_path, mtime_ns=100)
ref.is_missing = True
session.commit()
created, updated = upsert_reference(
session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100
)
session.commit()
assert created is False
assert updated is True
restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one()
assert restored_ref.is_missing is False
class TestRestoreReferencesByPaths:
def test_restores_missing_references(self, session: Session):
asset = _make_asset(session, "hash1")
missing_path = "/missing/file.bin"
active_path = "/active/file.bin"
missing_ref = _make_reference(session, asset, missing_path, name="missing")
missing_ref.is_missing = True
_make_reference(session, asset, active_path, name="active")
session.commit()
restored = restore_references_by_paths(session, [missing_path])
session.commit()
assert restored == 1
ref = session.query(AssetReference).filter_by(file_path=missing_path).one()
assert ref.is_missing is False
def test_empty_list_restores_nothing(self, session: Session):
restored = restore_references_by_paths(session, [])
assert restored == 0
class TestMarkReferencesMissingOutsidePrefixes:
def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
valid_dir = tmp_path / "valid"
valid_dir.mkdir()
invalid_dir = tmp_path / "invalid"
invalid_dir.mkdir()
valid_path = str(valid_dir / "file.bin")
invalid_path = str(invalid_dir / "file.bin")
_make_reference(session, asset, valid_path, name="valid")
_make_reference(session, asset, invalid_path, name="invalid")
session.commit()
marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)])
session.commit()
assert marked == 1
all_refs = session.query(AssetReference).all()
assert len(all_refs) == 2
valid_ref = next(r for r in all_refs if r.file_path == valid_path)
invalid_ref = next(r for r in all_refs if r.file_path == invalid_path)
assert valid_ref.is_missing is False
assert invalid_ref.is_missing is True
def test_empty_prefixes_marks_nothing(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/some/path.bin")
session.commit()
marked = mark_references_missing_outside_prefixes(session, [])
assert marked == 0
class TestGetUnreferencedUnhashedAssetIds:
def test_returns_unreferenced_unhashed_assets(self, session: Session):
# Unhashed asset (hash=None) with no references (no file_path)
no_refs = _make_asset(session, hash_val=None)
# Unhashed asset with active reference (not unreferenced)
with_active_ref = _make_asset(session, hash_val=None)
_make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref")
# Unhashed asset with only missing reference (should be unreferenced)
with_missing_ref = _make_asset(session, hash_val=None)
missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref")
missing_ref.is_missing = True
# Regular asset (hash not None) - should not be returned
_make_asset(session, hash_val="blake3:regular")
session.commit()
unreferenced = get_unreferenced_unhashed_asset_ids(session)
assert no_refs.id in unreferenced
assert with_missing_ref.id in unreferenced
assert with_active_ref.id not in unreferenced
class TestDeleteAssetsByIds:
def test_deletes_assets_and_references(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/test/path.bin", name="test")
session.commit()
deleted = delete_assets_by_ids(session, [asset.id])
session.commit()
assert deleted == 1
assert session.query(Asset).count() == 0
assert session.query(AssetReference).count() == 0
def test_empty_list_deletes_nothing(self, session: Session):
_make_asset(session, "hash1")
session.commit()
deleted = delete_assets_by_ids(session, [])
assert deleted == 0
assert session.query(Asset).count() == 1
class TestGetReferencesForPrefixes:
def test_returns_references_matching_prefix(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
dir1 = tmp_path / "dir1"
dir1.mkdir()
dir2 = tmp_path / "dir2"
dir2.mkdir()
path1 = str(dir1 / "file.bin")
path2 = str(dir2 / "file.bin")
_make_reference(session, asset, path1, name="file1", mtime_ns=100)
_make_reference(session, asset, path2, name="file2", mtime_ns=200)
session.commit()
rows = get_references_for_prefixes(session, [str(dir1)])
assert len(rows) == 1
assert rows[0].file_path == path1
def test_empty_prefixes_returns_empty(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/some/path.bin")
session.commit()
rows = get_references_for_prefixes(session, [])
assert rows == []
class TestBulkSetNeedsVerify:
def test_sets_needs_verify_flag(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False)
ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False)
session.commit()
updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True)
session.commit()
assert updated == 2
session.refresh(ref1)
session.refresh(ref2)
assert ref1.needs_verify is True
assert ref2.needs_verify is True
def test_empty_list_updates_nothing(self, session: Session):
updated = bulk_update_needs_verify(session, [], True)
assert updated == 0
class TestDeleteReferencesByIds:
def test_deletes_references_by_id(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, "/path1.bin")
_make_reference(session, asset, "/path2.bin")
session.commit()
deleted = delete_references_by_ids(session, [ref1.id])
session.commit()
assert deleted == 1
assert session.query(AssetReference).count() == 1
def test_empty_list_deletes_nothing(self, session: Session):
deleted = delete_references_by_ids(session, [])
assert deleted == 0
class TestDeleteOrphanedSeedAsset:
@pytest.mark.parametrize(
"create_asset,expected_deleted,expected_count",
[
(True, True, 0), # Existing asset gets deleted
(False, False, 0), # Nonexistent returns False
],
ids=["deletes_existing", "nonexistent_returns_false"],
)
def test_delete_orphaned_seed_asset(
self, session: Session, create_asset, expected_deleted, expected_count
):
asset_id = "nonexistent-id"
if create_asset:
asset = _make_asset(session, hash_val=None)
asset_id = asset.id
_make_reference(session, asset, "/test/path.bin", name="test")
session.commit()
deleted = delete_orphaned_seed_asset(session, asset_id)
if create_asset:
session.commit()
assert deleted is expected_deleted
assert session.query(Asset).count() == expected_count
class TestBulkInsertReferencesIgnoreConflicts:
def test_inserts_multiple_references(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"asset_id": asset.id,
"file_path": "/bulk1.bin",
"name": "bulk1",
"mtime_ns": 100,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"asset_id": asset.id,
"file_path": "/bulk2.bin",
"name": "bulk2",
"mtime_ns": 200,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetReference).count() == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/existing.bin", mtime_ns=100)
session.commit()
now = get_utc_now()
rows = [
{
"asset_id": asset.id,
"file_path": "/existing.bin",
"name": "existing",
"mtime_ns": 999,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"asset_id": asset.id,
"file_path": "/new.bin",
"name": "new",
"mtime_ns": 200,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetReference).count() == 2
existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one()
assert existing.mtime_ns == 100 # Original value preserved
def test_empty_list_is_noop(self, session: Session):
bulk_insert_references_ignore_conflicts(session, [])
assert session.query(AssetReference).count() == 0
class TestGetReferencesByPathsAndAssetIds:
def test_returns_matching_paths(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_reference(session, asset1, "/path1.bin")
_make_reference(session, asset2, "/path2.bin")
session.commit()
path_to_asset = {
"/path1.bin": asset1.id,
"/path2.bin": asset2.id,
}
winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
assert winners == {"/path1.bin", "/path2.bin"}
def test_excludes_non_matching_asset_ids(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_reference(session, asset1, "/path1.bin")
session.commit()
# Path exists but with different asset_id
path_to_asset = {"/path1.bin": asset2.id}
winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
assert winners == set()
def test_empty_dict_returns_empty(self, session: Session):
winners = get_references_by_paths_and_asset_ids(session, {})
assert winners == set()

View File

@ -0,0 +1,231 @@
"""Tests for metadata filtering logic in asset_reference queries."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
from app.assets.database.queries import list_references_page
from app.assets.database.queries.asset_reference import convert_metadata_to_rows
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
name: str,
metadata: dict | None = None,
system_metadata: dict | None = None,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
owner_id="",
name=name,
asset_id=asset.id,
user_metadata=metadata,
system_metadata=system_metadata,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
# Build merged projection: {**system_metadata, **user_metadata}
merged = {**(system_metadata or {}), **(metadata or {})}
if merged:
for key, val in merged.items():
for row in convert_metadata_to_rows(key, val):
meta_row = AssetReferenceMeta(
asset_reference_id=ref.id,
key=row["key"],
ordinal=row.get("ordinal", 0),
val_str=row.get("val_str"),
val_num=row.get("val_num"),
val_bool=row.get("val_bool"),
val_json=row.get("val_json"),
)
session.add(meta_row)
session.flush()
return ref
class TestMetadataFilterByType:
"""Table-driven tests for metadata filtering by different value types."""
@pytest.mark.parametrize(
"match_meta,nomatch_meta,filter_key,filter_val",
[
# String matching
({"category": "models"}, {"category": "images"}, "category", "models"),
# Integer matching
({"epoch": 5}, {"epoch": 10}, "epoch", 5),
# Float matching
({"score": 0.95}, {"score": 0.5}, "score", 0.95),
# Boolean True matching
({"enabled": True}, {"enabled": False}, "enabled", True),
# Boolean False matching
({"enabled": False}, {"enabled": True}, "enabled", False),
],
ids=["string", "int", "float", "bool_true", "bool_false"],
)
def test_filter_matches_correct_value(
self, session: Session, match_meta, nomatch_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "match", match_meta)
_make_reference(session, asset, "nomatch", nomatch_meta)
session.commit()
refs, _, total = list_references_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 1
assert refs[0].name == "match"
@pytest.mark.parametrize(
"stored_meta,filter_key,filter_val",
[
# String no match
({"category": "models"}, "category", "other"),
# Int no match
({"epoch": 5}, "epoch", 99),
# Float no match
({"score": 0.5}, "score", 0.99),
],
ids=["string_no_match", "int_no_match", "float_no_match"],
)
def test_filter_returns_empty_when_no_match(
self, session: Session, stored_meta, filter_key, filter_val
):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "item", stored_meta)
session.commit()
refs, _, total = list_references_page(
session, metadata_filter={filter_key: filter_val}
)
assert total == 0
class TestMetadataFilterNull:
"""Tests for null/missing key filtering."""
@pytest.mark.parametrize(
"match_name,match_meta,nomatch_name,nomatch_meta,filter_key",
[
# Null matches missing key
("missing_key", {}, "has_key", {"optional": "value"}, "optional"),
# Null matches explicit null
("explicit_null", {"nullable": None}, "has_value", {"nullable": "present"}, "nullable"),
],
ids=["missing_key", "explicit_null"],
)
def test_null_filter_matches(
self, session: Session, match_name, match_meta, nomatch_name, nomatch_meta, filter_key
):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, match_name, match_meta)
_make_reference(session, asset, nomatch_name, nomatch_meta)
session.commit()
refs, _, total = list_references_page(session, metadata_filter={filter_key: None})
assert total == 1
assert refs[0].name == match_name
class TestMetadataFilterList:
"""Tests for list-based (OR) filtering."""
def test_filter_by_list_matches_any(self, session: Session):
"""List values should match ANY of the values (OR)."""
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "cat_a", {"category": "a"})
_make_reference(session, asset, "cat_b", {"category": "b"})
_make_reference(session, asset, "cat_c", {"category": "c"})
session.commit()
refs, _, total = list_references_page(session, metadata_filter={"category": ["a", "b"]})
assert total == 2
names = {r.name for r in refs}
assert names == {"cat_a", "cat_b"}
class TestMetadataFilterMultipleKeys:
"""Tests for multiple filter keys (AND semantics)."""
def test_multiple_keys_must_all_match(self, session: Session):
"""Multiple keys should ALL match (AND)."""
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "match", {"type": "model", "version": 2})
_make_reference(session, asset, "wrong_type", {"type": "config", "version": 2})
_make_reference(session, asset, "wrong_version", {"type": "model", "version": 1})
session.commit()
refs, _, total = list_references_page(
session, metadata_filter={"type": "model", "version": 2}
)
assert total == 1
assert refs[0].name == "match"
class TestMetadataFilterEmptyDict:
"""Tests for empty filter behavior."""
def test_empty_filter_returns_all(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "a", {"key": "val"})
_make_reference(session, asset, "b", {})
session.commit()
refs, _, total = list_references_page(session, metadata_filter={})
assert total == 2
class TestSystemMetadataProjection:
"""Tests for system_metadata merging into the filter projection."""
def test_system_metadata_keys_are_filterable(self, session: Session):
"""system_metadata keys should appear in the merged projection."""
asset = _make_asset(session, "hash1")
_make_reference(
session, asset, "with_sys",
system_metadata={"source": "scanner"},
)
_make_reference(session, asset, "without_sys")
session.commit()
refs, _, total = list_references_page(
session, metadata_filter={"source": "scanner"}
)
assert total == 1
assert refs[0].name == "with_sys"
def test_user_metadata_overrides_system_metadata(self, session: Session):
"""user_metadata should win when both have the same key."""
asset = _make_asset(session, "hash1")
_make_reference(
session, asset, "overridden",
metadata={"origin": "user_upload"},
system_metadata={"origin": "auto_scan"},
)
session.commit()
# Should match the user value, not the system value
refs, _, total = list_references_page(
session, metadata_filter={"origin": "user_upload"}
)
assert total == 1
assert refs[0].name == "overridden"
# Should NOT match the system value (it was overridden)
refs, _, total = list_references_page(
session, metadata_filter={"origin": "auto_scan"}
)
assert total == 0

View File

@ -0,0 +1,366 @@
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, AssetReferenceMeta, Tag
from app.assets.database.queries import (
ensure_tags_exist,
get_reference_tags,
set_reference_tags,
add_tags_to_reference,
remove_tags_from_reference,
add_missing_tag_for_asset_id,
remove_missing_tag_for_asset_id,
list_tags_with_usage,
bulk_insert_tags_and_meta,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None) -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id: str = "") -> AssetReference:
now = get_utc_now()
ref = AssetReference(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestEnsureTagsExist:
def test_creates_new_tags(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_is_idempotent(self, session: Session):
ensure_tags_exist(session, ["alpha"], tag_type="user")
ensure_tags_exist(session, ["alpha"], tag_type="user")
session.commit()
assert session.query(Tag).count() == 1
def test_normalizes_tags(self, session: Session):
ensure_tags_exist(session, [" ALPHA ", "Beta", "alpha"])
session.commit()
tags = session.query(Tag).all()
assert {t.name for t in tags} == {"alpha", "beta"}
def test_empty_list_is_noop(self, session: Session):
ensure_tags_exist(session, [])
session.commit()
assert session.query(Tag).count() == 0
def test_tag_type_is_set(self, session: Session):
ensure_tags_exist(session, ["system-tag"], tag_type="system")
session.commit()
tag = session.query(Tag).filter_by(name="system-tag").one()
assert tag.tag_type == "system"
class TestGetReferenceTags:
def test_returns_empty_for_no_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
tags = get_reference_tags(session, reference_id=ref.id)
assert tags == []
def test_returns_tags_for_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
ensure_tags_exist(session, ["tag1", "tag2"])
session.add_all([
AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag1", origin="manual", added_at=get_utc_now()),
AssetReferenceTag(asset_reference_id=ref.id, tag_name="tag2", origin="manual", added_at=get_utc_now()),
])
session.flush()
tags = get_reference_tags(session, reference_id=ref.id)
assert set(tags) == {"tag1", "tag2"}
class TestSetReferenceTags:
def test_adds_new_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
result = set_reference_tags(session, reference_id=ref.id, tags=["a", "b"])
session.commit()
assert set(result.added) == {"a", "b"}
assert result.removed == []
assert set(result.total) == {"a", "b"}
def test_removes_old_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
set_reference_tags(session, reference_id=ref.id, tags=["a", "b", "c"])
result = set_reference_tags(session, reference_id=ref.id, tags=["a"])
session.commit()
assert result.added == []
assert set(result.removed) == {"b", "c"}
assert result.total == ["a"]
def test_replaces_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
set_reference_tags(session, reference_id=ref.id, tags=["a", "b"])
result = set_reference_tags(session, reference_id=ref.id, tags=["b", "c"])
session.commit()
assert result.added == ["c"]
assert result.removed == ["a"]
assert set(result.total) == {"b", "c"}
class TestAddTagsToReference:
def test_adds_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
session.commit()
assert set(result.added) == {"x", "y"}
assert result.already_present == []
def test_reports_already_present(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
add_tags_to_reference(session, reference_id=ref.id, tags=["x"])
result = add_tags_to_reference(session, reference_id=ref.id, tags=["x", "y"])
session.commit()
assert result.added == ["y"]
assert result.already_present == ["x"]
def test_raises_for_missing_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"):
add_tags_to_reference(session, reference_id="nonexistent", tags=["x"])
class TestRemoveTagsFromReference:
def test_removes_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
add_tags_to_reference(session, reference_id=ref.id, tags=["a", "b", "c"])
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "b"])
session.commit()
assert set(result.removed) == {"a", "b"}
assert result.not_present == []
assert result.total_tags == ["c"]
def test_reports_not_present(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
add_tags_to_reference(session, reference_id=ref.id, tags=["a"])
result = remove_tags_from_reference(session, reference_id=ref.id, tags=["a", "x"])
session.commit()
assert result.removed == ["a"]
assert result.not_present == ["x"]
def test_raises_for_missing_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"):
remove_tags_from_reference(session, reference_id="nonexistent", tags=["x"])
class TestMissingTagFunctions:
def test_add_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_reference_tags(session, reference_id=ref.id)
assert "missing" in tags
def test_add_missing_tag_is_idempotent(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
add_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="missing").all()
assert len(links) == 1
def test_remove_missing_tag_for_asset_id(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
ensure_tags_exist(session, ["missing"], tag_type="system")
add_missing_tag_for_asset_id(session, asset_id=asset.id)
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
session.commit()
tags = get_reference_tags(session, reference_id=ref.id)
assert "missing" not in tags
class TestListTagsWithUsage:
def test_returns_tags_with_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
add_tags_to_reference(session, reference_id=ref.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict["used"] == 1
assert tag_dict["unused"] == 0
assert total == 2
def test_exclude_zero_counts(self, session: Session):
ensure_tags_exist(session, ["used", "unused"])
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
add_tags_to_reference(session, reference_id=ref.id, tags=["used"])
session.commit()
rows, total = list_tags_with_usage(session, include_zero=False)
tag_names = {name for name, _, _ in rows}
assert "used" in tag_names
assert "unused" not in tag_names
def test_prefix_filter(self, session: Session):
ensure_tags_exist(session, ["alpha", "beta", "alphabet"])
session.commit()
rows, total = list_tags_with_usage(session, prefix="alph")
tag_names = {name for name, _, _ in rows}
assert tag_names == {"alpha", "alphabet"}
def test_order_by_name(self, session: Session):
ensure_tags_exist(session, ["zebra", "alpha", "middle"])
session.commit()
rows, _ = list_tags_with_usage(session, order="name_asc")
names = [name for name, _, _ in rows]
assert names == ["alpha", "middle", "zebra"]
def test_owner_visibility(self, session: Session):
ensure_tags_exist(session, ["shared-tag", "owner-tag"])
asset = _make_asset(session, "hash1")
shared_ref = _make_reference(session, asset, name="shared", owner_id="")
owner_ref = _make_reference(session, asset, name="owned", owner_id="user1")
add_tags_to_reference(session, reference_id=shared_ref.id, tags=["shared-tag"])
add_tags_to_reference(session, reference_id=owner_ref.id, tags=["owner-tag"])
session.commit()
# Empty owner sees only shared
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 0
# User1 sees both
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
tag_dict = {name: count for name, _, count in rows}
assert tag_dict.get("shared-tag", 0) == 1
assert tag_dict.get("owner-tag", 0) == 1
class TestBulkInsertTagsAndMeta:
def test_inserts_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
ensure_tags_exist(session, ["bulk-tag1", "bulk-tag2"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_reference_id": ref.id, "tag_name": "bulk-tag1", "origin": "manual", "added_at": now},
{"asset_reference_id": ref.id, "tag_name": "bulk-tag2", "origin": "manual", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
tags = get_reference_tags(session, reference_id=ref.id)
assert set(tags) == {"bulk-tag1", "bulk-tag2"}
def test_inserts_meta(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
meta_rows = [
{
"asset_reference_id": ref.id,
"key": "meta-key",
"ordinal": 0,
"val_str": "meta-value",
"val_num": None,
"val_bool": None,
"val_json": None,
},
]
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=meta_rows)
session.commit()
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 1
assert meta[0].key == "meta-key"
assert meta[0].val_str == "meta-value"
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
ensure_tags_exist(session, ["existing-tag"])
add_tags_to_reference(session, reference_id=ref.id, tags=["existing-tag"])
session.commit()
now = get_utc_now()
tag_rows = [
{"asset_reference_id": ref.id, "tag_name": "existing-tag", "origin": "duplicate", "added_at": now},
]
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=[])
session.commit()
# Should still have only one tag link
links = session.query(AssetReferenceTag).filter_by(asset_reference_id=ref.id, tag_name="existing-tag").all()
assert len(links) == 1
# Origin should be original, not overwritten
assert links[0].origin == "manual"
def test_empty_lists_is_noop(self, session: Session):
bulk_insert_tags_and_meta(session, tag_rows=[], meta_rows=[])
assert session.query(AssetReferenceTag).count() == 0
assert session.query(AssetReferenceMeta).count() == 0

Some files were not shown because too many files have changed in this diff Show More