Merge pull request #12898 from pollockjj/pyisolate-pr-final
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

feat(isolation): upstream master sync + serializers, save nodes, and fixes
This commit is contained in:
John Pollock 2026-03-12 01:48:58 -05:00 committed by GitHub
commit 54461f9ecc
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
123 changed files with 14085 additions and 3229 deletions

1
.gitignore vendored
View File

@ -24,3 +24,4 @@ web_custom_versions/
openapi.yaml
filtered-openapi.yaml
uv.lock
.pyisolate_venvs/

View File

@ -0,0 +1,267 @@
"""
Merge AssetInfo and AssetCacheState into unified asset_references table.
This migration drops old tables and creates the new unified schema.
All existing data is discarded.
Revision ID: 0002_merge_to_asset_references
Revises: 0001_assets
Create Date: 2025-02-11
"""
from alembic import op
import sqlalchemy as sa
revision = "0002_merge_to_asset_references"
down_revision = "0001_assets"
branch_labels = None
depends_on = None
def upgrade() -> None:
# Drop old tables (order matters due to FK constraints)
op.drop_index("ix_asset_info_meta_key_val_bool", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_num", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key_val_str", table_name="asset_info_meta")
op.drop_index("ix_asset_info_meta_key", table_name="asset_info_meta")
op.drop_table("asset_info_meta")
op.drop_index("ix_asset_info_tags_asset_info_id", table_name="asset_info_tags")
op.drop_index("ix_asset_info_tags_tag_name", table_name="asset_info_tags")
op.drop_table("asset_info_tags")
op.drop_index("ix_asset_cache_state_asset_id", table_name="asset_cache_state")
op.drop_index("ix_asset_cache_state_file_path", table_name="asset_cache_state")
op.drop_table("asset_cache_state")
op.drop_index("ix_assets_info_owner_name", table_name="assets_info")
op.drop_index("ix_assets_info_last_access_time", table_name="assets_info")
op.drop_index("ix_assets_info_created_at", table_name="assets_info")
op.drop_index("ix_assets_info_name", table_name="assets_info")
op.drop_index("ix_assets_info_asset_id", table_name="assets_info")
op.drop_index("ix_assets_info_owner_id", table_name="assets_info")
op.drop_table("assets_info")
# Truncate assets table (cascades handled by dropping dependent tables first)
op.execute("DELETE FROM assets")
# Create asset_references table
op.create_table(
"asset_references",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column(
"asset_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("file_path", sa.Text(), nullable=True),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column(
"needs_verify",
sa.Boolean(),
nullable=False,
server_default=sa.text("false"),
),
sa.Column(
"is_missing", sa.Boolean(), nullable=False, server_default=sa.text("false")
),
sa.Column("enrichment_level", sa.Integer(), nullable=False, server_default="0"),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column(
"preview_id",
sa.String(length=36),
sa.ForeignKey("assets.id", ondelete="SET NULL"),
nullable=True,
),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.Column("deleted_at", sa.DateTime(timezone=False), nullable=True),
sa.CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
sa.CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
)
op.create_index(
"uq_asset_references_file_path", "asset_references", ["file_path"], unique=True
)
op.create_index("ix_asset_references_asset_id", "asset_references", ["asset_id"])
op.create_index("ix_asset_references_owner_id", "asset_references", ["owner_id"])
op.create_index("ix_asset_references_name", "asset_references", ["name"])
op.create_index("ix_asset_references_is_missing", "asset_references", ["is_missing"])
op.create_index(
"ix_asset_references_enrichment_level", "asset_references", ["enrichment_level"]
)
op.create_index("ix_asset_references_created_at", "asset_references", ["created_at"])
op.create_index(
"ix_asset_references_last_access_time", "asset_references", ["last_access_time"]
)
op.create_index(
"ix_asset_references_owner_name", "asset_references", ["owner_id", "name"]
)
op.create_index("ix_asset_references_deleted_at", "asset_references", ["deleted_at"])
# Create asset_reference_tags table
op.create_table(
"asset_reference_tags",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column(
"tag_name",
sa.String(length=512),
sa.ForeignKey("tags.name", ondelete="RESTRICT"),
nullable=False,
),
sa.Column(
"origin", sa.String(length=32), nullable=False, server_default="manual"
),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint(
"asset_reference_id", "tag_name", name="pk_asset_reference_tags"
),
)
op.create_index(
"ix_asset_reference_tags_tag_name", "asset_reference_tags", ["tag_name"]
)
op.create_index(
"ix_asset_reference_tags_asset_reference_id",
"asset_reference_tags",
["asset_reference_id"],
)
# Create asset_reference_meta table
op.create_table(
"asset_reference_meta",
sa.Column(
"asset_reference_id",
sa.String(length=36),
sa.ForeignKey("asset_references.id", ondelete="CASCADE"),
nullable=False,
),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint(
"asset_reference_id", "key", "ordinal", name="pk_asset_reference_meta"
),
)
op.create_index("ix_asset_reference_meta_key", "asset_reference_meta", ["key"])
op.create_index(
"ix_asset_reference_meta_key_val_str", "asset_reference_meta", ["key", "val_str"]
)
op.create_index(
"ix_asset_reference_meta_key_val_num", "asset_reference_meta", ["key", "val_num"]
)
op.create_index(
"ix_asset_reference_meta_key_val_bool",
"asset_reference_meta",
["key", "val_bool"],
)
def downgrade() -> None:
"""Reverse 0002_merge_to_asset_references: drop new tables, recreate old schema.
NOTE: Data is not recoverable. The upgrade discards all rows from the old
tables and truncates assets. After downgrade the old schema will be empty.
A filesystem rescan will repopulate data once the older code is running.
"""
# Drop new tables (order matters due to FK constraints)
op.drop_index("ix_asset_reference_meta_key_val_bool", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_num", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key_val_str", table_name="asset_reference_meta")
op.drop_index("ix_asset_reference_meta_key", table_name="asset_reference_meta")
op.drop_table("asset_reference_meta")
op.drop_index("ix_asset_reference_tags_asset_reference_id", table_name="asset_reference_tags")
op.drop_index("ix_asset_reference_tags_tag_name", table_name="asset_reference_tags")
op.drop_table("asset_reference_tags")
op.drop_index("ix_asset_references_deleted_at", table_name="asset_references")
op.drop_index("ix_asset_references_owner_name", table_name="asset_references")
op.drop_index("ix_asset_references_last_access_time", table_name="asset_references")
op.drop_index("ix_asset_references_created_at", table_name="asset_references")
op.drop_index("ix_asset_references_enrichment_level", table_name="asset_references")
op.drop_index("ix_asset_references_is_missing", table_name="asset_references")
op.drop_index("ix_asset_references_name", table_name="asset_references")
op.drop_index("ix_asset_references_owner_id", table_name="asset_references")
op.drop_index("ix_asset_references_asset_id", table_name="asset_references")
op.drop_index("uq_asset_references_file_path", table_name="asset_references")
op.drop_table("asset_references")
# Truncate assets (upgrade deleted all rows; downgrade starts fresh too)
op.execute("DELETE FROM assets")
# Recreate old tables from 0001_assets schema
op.create_table(
"assets_info",
sa.Column("id", sa.String(length=36), primary_key=True),
sa.Column("owner_id", sa.String(length=128), nullable=False, server_default=""),
sa.Column("name", sa.String(length=512), nullable=False),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False),
sa.Column("preview_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="SET NULL"), nullable=True),
sa.Column("user_metadata", sa.JSON(), nullable=True),
sa.Column("created_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("updated_at", sa.DateTime(timezone=False), nullable=False),
sa.Column("last_access_time", sa.DateTime(timezone=False), nullable=False),
sa.UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
)
op.create_index("ix_assets_info_owner_id", "assets_info", ["owner_id"])
op.create_index("ix_assets_info_asset_id", "assets_info", ["asset_id"])
op.create_index("ix_assets_info_name", "assets_info", ["name"])
op.create_index("ix_assets_info_created_at", "assets_info", ["created_at"])
op.create_index("ix_assets_info_last_access_time", "assets_info", ["last_access_time"])
op.create_index("ix_assets_info_owner_name", "assets_info", ["owner_id", "name"])
op.create_table(
"asset_cache_state",
sa.Column("id", sa.Integer(), primary_key=True, autoincrement=True),
sa.Column("asset_id", sa.String(length=36), sa.ForeignKey("assets.id", ondelete="CASCADE"), nullable=False),
sa.Column("file_path", sa.Text(), nullable=False),
sa.Column("mtime_ns", sa.BigInteger(), nullable=True),
sa.Column("needs_verify", sa.Boolean(), nullable=False, server_default=sa.text("false")),
sa.CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
sa.UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
)
op.create_index("ix_asset_cache_state_file_path", "asset_cache_state", ["file_path"])
op.create_index("ix_asset_cache_state_asset_id", "asset_cache_state", ["asset_id"])
op.create_table(
"asset_info_tags",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("tag_name", sa.String(length=512), sa.ForeignKey("tags.name", ondelete="RESTRICT"), nullable=False),
sa.Column("origin", sa.String(length=32), nullable=False, server_default="manual"),
sa.Column("added_at", sa.DateTime(timezone=False), nullable=False),
sa.PrimaryKeyConstraint("asset_info_id", "tag_name", name="pk_asset_info_tags"),
)
op.create_index("ix_asset_info_tags_tag_name", "asset_info_tags", ["tag_name"])
op.create_index("ix_asset_info_tags_asset_info_id", "asset_info_tags", ["asset_info_id"])
op.create_table(
"asset_info_meta",
sa.Column("asset_info_id", sa.String(length=36), sa.ForeignKey("assets_info.id", ondelete="CASCADE"), nullable=False),
sa.Column("key", sa.String(length=256), nullable=False),
sa.Column("ordinal", sa.Integer(), nullable=False, server_default="0"),
sa.Column("val_str", sa.String(length=2048), nullable=True),
sa.Column("val_num", sa.Numeric(38, 10), nullable=True),
sa.Column("val_bool", sa.Boolean(), nullable=True),
sa.Column("val_json", sa.JSON(), nullable=True),
sa.PrimaryKeyConstraint("asset_info_id", "key", "ordinal", name="pk_asset_info_meta"),
)
op.create_index("ix_asset_info_meta_key", "asset_info_meta", ["key"])
op.create_index("ix_asset_info_meta_key_val_str", "asset_info_meta", ["key", "val_str"])
op.create_index("ix_asset_info_meta_key_val_num", "asset_info_meta", ["key", "val_num"])
op.create_index("ix_asset_info_meta_key_val_bool", "asset_info_meta", ["key", "val_bool"])

File diff suppressed because it is too large Load Diff

View File

@ -1,6 +1,8 @@
import json
from dataclasses import dataclass
from typing import Any, Literal
from app.assets.helpers import validate_blake3_hash
from pydantic import (
BaseModel,
ConfigDict,
@ -10,6 +12,41 @@ from pydantic import (
model_validator,
)
class UploadError(Exception):
"""Error during upload parsing with HTTP status and code."""
def __init__(self, status: int, code: str, message: str):
super().__init__(message)
self.status = status
self.code = code
self.message = message
class AssetValidationError(Exception):
"""Validation error in asset processing (invalid tags, metadata, etc.)."""
def __init__(self, code: str, message: str):
super().__init__(message)
self.code = code
self.message = message
@dataclass
class ParsedUpload:
"""Result of parsing a multipart upload request."""
file_present: bool
file_written: int
file_client_name: str | None
tmp_path: str | None
tags_raw: list[str]
provided_name: str | None
user_metadata_raw: str | None
provided_hash: str | None
provided_hash_exists: bool | None
class ListAssetsQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
@ -21,7 +58,9 @@ class ListAssetsQuery(BaseModel):
limit: conint(ge=1, le=500) = 20
offset: conint(ge=0) = 0
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = "created_at"
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
"created_at"
)
order: Literal["asc", "desc"] = "desc"
@field_validator("include_tags", "exclude_tags", mode="before")
@ -61,7 +100,7 @@ class UpdateAssetBody(BaseModel):
user_metadata: dict[str, Any] | None = None
@model_validator(mode="after")
def _at_least_one(self):
def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None:
raise ValueError("Provide at least one of: name, user_metadata.")
return self
@ -78,19 +117,11 @@ class CreateFromHashBody(BaseModel):
@field_validator("hash")
@classmethod
def _require_blake3(cls, v):
s = (v or "").strip().lower()
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return s
return validate_blake3_hash(v or "")
@field_validator("tags", mode="before")
@classmethod
def _tags_norm(cls, v):
def _normalize_tags_field(cls, v):
if v is None:
return []
if isinstance(v, list):
@ -154,15 +185,16 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel):
"""Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category from folder_paths.folder_names_and_paths
if root == 'models', second must be a valid category
- name: display name
- user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' provided by the client for validation / fast-path
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
Files created via this endpoint are stored on disk using the **content hash** as the filename stem
and the original extension is preserved when available.
Files are stored using the content hash as filename stem.
"""
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1)
@ -175,17 +207,10 @@ class UploadAssetSpec(BaseModel):
def _parse_hash(cls, v):
if v is None:
return None
s = str(v).strip().lower()
s = str(v).strip()
if not s:
return None
if ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if algo != "blake3":
raise ValueError("only canonical 'blake3:<hex>' is accepted here")
if not digest or any(c for c in digest if c not in "0123456789abcdef"):
raise ValueError("hash digest must be lowercase hex")
return f"{algo}:{digest}"
return validate_blake3_hash(s)
@field_validator("tags", mode="before")
@classmethod
@ -260,5 +285,7 @@ class UploadAssetSpec(BaseModel):
raise ValueError("first tag must be one of: models, input, output")
if root == "models":
if len(self.tags) < 2:
raise ValueError("models uploads require a category tag as the second tag")
raise ValueError(
"models uploads require a category tag as the second tag"
)
return self

View File

@ -19,7 +19,7 @@ class AssetSummary(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "updated_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
@ -40,7 +40,7 @@ class AssetUpdated(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _ser_updated(self, v: datetime | None, _info):
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
@ -59,7 +59,7 @@ class AssetDetail(BaseModel):
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _ser_dt(self, v: datetime | None, _info):
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None

171
app/assets/api/upload.py Normal file
View File

@ -0,0 +1,171 @@
import logging
import os
import uuid
from typing import Callable
from aiohttp import web
import folder_paths
from app.assets.api.schemas_in import ParsedUpload, UploadError
from app.assets.helpers import validate_blake3_hash
def normalize_and_validate_hash(s: str) -> str:
"""Validate and normalize a hash string.
Returns canonical 'blake3:<hex>' or raises UploadError.
"""
try:
return validate_blake3_hash(s)
except ValueError:
raise UploadError(400, "INVALID_HASH", "hash must be like 'blake3:<hex>'")
async def parse_multipart_upload(
request: web.Request,
check_hash_exists: Callable[[str], bool],
) -> ParsedUpload:
"""
Parse a multipart/form-data upload request.
Args:
request: The aiohttp request
check_hash_exists: Callable(hash_str) -> bool to check if a hash exists
Returns:
ParsedUpload with parsed fields and temp file path
Raises:
UploadError: On validation or I/O errors
"""
if not (request.content_type or "").lower().startswith("multipart/"):
raise UploadError(
415, "UNSUPPORTED_MEDIA_TYPE", "Use multipart/form-data for uploads."
)
reader = await request.multipart()
file_present = False
file_client_name: str | None = None
tags_raw: list[str] = []
provided_name: str | None = None
user_metadata_raw: str | None = None
provided_hash: str | None = None
provided_hash_exists: bool | None = None
file_written = 0
tmp_path: str | None = None
while True:
field = await reader.next()
if field is None:
break
fname = getattr(field, "name", "") or ""
if fname == "hash":
try:
s = ((await field.text()) or "").strip().lower()
except Exception:
raise UploadError(
400, "INVALID_HASH", "hash must be like 'blake3:<hex>'"
)
if s:
provided_hash = normalize_and_validate_hash(s)
try:
provided_hash_exists = check_hash_exists(provided_hash)
except Exception as e:
logging.exception(
"check_hash_exists failed for hash=%s: %s", provided_hash, e
)
raise UploadError(
500,
"HASH_CHECK_FAILED",
"Backend error while checking asset hash.",
)
elif fname == "file":
file_present = True
file_client_name = (field.filename or "").strip()
if provided_hash and provided_hash_exists is True:
# Hash exists - drain file but don't write to disk
try:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
file_written += len(chunk)
except Exception:
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive uploaded file."
)
continue
uploads_root = os.path.join(folder_paths.get_temp_directory(), "uploads")
unique_dir = os.path.join(uploads_root, uuid.uuid4().hex)
os.makedirs(unique_dir, exist_ok=True)
tmp_path = os.path.join(unique_dir, ".upload.part")
try:
with open(tmp_path, "wb") as f:
while True:
chunk = await field.read_chunk(8 * 1024 * 1024)
if not chunk:
break
f.write(chunk)
file_written += len(chunk)
except Exception:
delete_temp_file_if_exists(tmp_path)
raise UploadError(
500, "UPLOAD_IO_ERROR", "Failed to receive and store uploaded file."
)
elif fname == "tags":
tags_raw.append((await field.text()) or "")
elif fname == "name":
provided_name = (await field.text()) or None
elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None
if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError(
400, "MISSING_FILE", "Form must include a 'file' part or a known 'hash'."
)
if (
file_present
and file_written == 0
and not (provided_hash and provided_hash_exists)
):
delete_temp_file_if_exists(tmp_path)
raise UploadError(400, "EMPTY_UPLOAD", "Uploaded file is empty.")
return ParsedUpload(
file_present=file_present,
file_written=file_written,
file_client_name=file_client_name,
tmp_path=tmp_path,
tags_raw=tags_raw,
provided_name=provided_name,
user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists,
)
def delete_temp_file_if_exists(tmp_path: str | None) -> None:
"""Safely remove a temp file and its parent directory if empty."""
if tmp_path:
try:
if os.path.exists(tmp_path):
os.remove(tmp_path)
except OSError as e:
logging.debug("Failed to delete temp file %s: %s", tmp_path, e)
try:
parent = os.path.dirname(tmp_path)
if parent and os.path.isdir(parent):
os.rmdir(parent) # only succeeds if empty
except OSError:
pass

View File

@ -1,204 +0,0 @@
import os
import uuid
import sqlalchemy
from typing import Iterable
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import utcnow
from app.assets.database.models import Asset, AssetCacheState, AssetInfo, AssetInfoTag, AssetInfoMeta
MAX_BIND_PARAMS = 800
def _chunk_rows(rows: list[dict], cols_per_row: int, max_bind_params: int) -> Iterable[list[dict]]:
if not rows:
return []
rows_per_stmt = max(1, max_bind_params // max(1, cols_per_row))
for i in range(0, len(rows), rows_per_stmt):
yield rows[i:i + rows_per_stmt]
def _iter_chunks(seq, n: int):
for i in range(0, len(seq), n):
yield seq[i:i + n]
def _rows_per_stmt(cols: int) -> int:
return max(1, MAX_BIND_PARAMS // max(1, cols))
def seed_from_paths_batch(
session: Session,
*,
specs: list[dict],
owner_id: str = "",
) -> dict:
"""Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
"""
if not specs:
return {"inserted_infos": 0, "won_states": 0, "lost_states": 0}
now = utcnow()
asset_rows: list[dict] = []
state_rows: list[dict] = []
path_to_asset: dict[str, str] = {}
asset_to_info: dict[str, dict] = {} # asset_id -> prepared info row
path_list: list[str] = []
for sp in specs:
ap = os.path.abspath(sp["abs_path"])
aid = str(uuid.uuid4())
iid = str(uuid.uuid4())
path_list.append(ap)
path_to_asset[ap] = aid
asset_rows.append(
{
"id": aid,
"hash": None,
"size_bytes": sp["size_bytes"],
"mime_type": None,
"created_at": now,
}
)
state_rows.append(
{
"asset_id": aid,
"file_path": ap,
"mtime_ns": sp["mtime_ns"],
}
)
asset_to_info[aid] = {
"id": iid,
"owner_id": owner_id,
"name": sp["info_name"],
"asset_id": aid,
"preview_id": None,
"user_metadata": {"filename": sp["fname"]} if sp["fname"] else None,
"created_at": now,
"updated_at": now,
"last_access_time": now,
"_tags": sp["tags"],
"_filename": sp["fname"],
}
# insert all seed Assets (hash=NULL)
ins_asset = sqlite.insert(Asset)
for chunk in _iter_chunks(asset_rows, _rows_per_stmt(5)):
session.execute(ins_asset, chunk)
# try to claim AssetCacheState (file_path)
# Insert with ON CONFLICT DO NOTHING, then query to find which paths were actually inserted
ins_state = (
sqlite.insert(AssetCacheState)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
for chunk in _iter_chunks(state_rows, _rows_per_stmt(3)):
session.execute(ins_state, chunk)
# Query to find which of our paths won (were actually inserted)
winners_by_path: set[str] = set()
for chunk in _iter_chunks(path_list, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetCacheState.file_path)
.where(AssetCacheState.file_path.in_(chunk))
.where(AssetCacheState.asset_id.in_([path_to_asset[p] for p in chunk]))
)
winners_by_path.update(result.scalars().all())
all_paths_set = set(path_list)
losers_by_path = all_paths_set - winners_by_path
lost_assets = [path_to_asset[p] for p in losers_by_path]
if lost_assets: # losers get their Asset removed
for id_chunk in _iter_chunks(lost_assets, MAX_BIND_PARAMS):
session.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(id_chunk)))
if not winners_by_path:
return {"inserted_infos": 0, "won_states": 0, "lost_states": len(losers_by_path)}
# insert AssetInfo only for winners
# Insert with ON CONFLICT DO NOTHING, then query to find which were actually inserted
winner_info_rows = [asset_to_info[path_to_asset[p]] for p in winners_by_path]
ins_info = (
sqlite.insert(AssetInfo)
.on_conflict_do_nothing(index_elements=[AssetInfo.asset_id, AssetInfo.owner_id, AssetInfo.name])
)
for chunk in _iter_chunks(winner_info_rows, _rows_per_stmt(9)):
session.execute(ins_info, chunk)
# Query to find which info rows were actually inserted (by matching our generated IDs)
all_info_ids = [row["id"] for row in winner_info_rows]
inserted_info_ids: set[str] = set()
for chunk in _iter_chunks(all_info_ids, MAX_BIND_PARAMS):
result = session.execute(
sqlalchemy.select(AssetInfo.id).where(AssetInfo.id.in_(chunk))
)
inserted_info_ids.update(result.scalars().all())
# build and insert tag + meta rows for the AssetInfo
tag_rows: list[dict] = []
meta_rows: list[dict] = []
if inserted_info_ids:
for row in winner_info_rows:
iid = row["id"]
if iid not in inserted_info_ids:
continue
for t in row["_tags"]:
tag_rows.append({
"asset_info_id": iid,
"tag_name": t,
"origin": "automatic",
"added_at": now,
})
if row["_filename"]:
meta_rows.append(
{
"asset_info_id": iid,
"key": "filename",
"ordinal": 0,
"val_str": row["_filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=meta_rows, max_bind_params=MAX_BIND_PARAMS)
return {
"inserted_infos": len(inserted_info_ids),
"won_states": len(winners_by_path),
"lost_states": len(losers_by_path),
}
def bulk_insert_tags_and_meta(
session: Session,
*,
tag_rows: list[dict],
meta_rows: list[dict],
max_bind_params: int,
) -> None:
"""Batch insert into asset_info_tags and asset_info_meta with ON CONFLICT DO NOTHING.
- tag_rows keys: asset_info_id, tag_name, origin, added_at
- meta_rows keys: asset_info_id, key, ordinal, val_str, val_num, val_bool, val_json
"""
if tag_rows:
ins_links = (
sqlite.insert(AssetInfoTag)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
for chunk in _chunk_rows(tag_rows, cols_per_row=4, max_bind_params=max_bind_params):
session.execute(ins_links, chunk)
if meta_rows:
ins_meta = (
sqlite.insert(AssetInfoMeta)
.on_conflict_do_nothing(
index_elements=[AssetInfoMeta.asset_info_id, AssetInfoMeta.key, AssetInfoMeta.ordinal]
)
)
for chunk in _chunk_rows(meta_rows, cols_per_row=7, max_bind_params=max_bind_params):
session.execute(ins_meta, chunk)

View File

@ -2,8 +2,8 @@ from __future__ import annotations
import uuid
from datetime import datetime
from typing import Any
from sqlalchemy import (
JSON,
BigInteger,
@ -16,102 +16,102 @@ from sqlalchemy import (
Numeric,
String,
Text,
UniqueConstraint,
)
from sqlalchemy.orm import Mapped, foreign, mapped_column, relationship
from app.assets.helpers import utcnow
from app.database.models import to_dict, Base
from app.assets.helpers import get_utc_now
from app.database.models import Base
class Asset(Base):
__tablename__ = "assets"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
hash: Mapped[str | None] = mapped_column(String(256), nullable=True)
size_bytes: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)
mime_type: Mapped[str | None] = mapped_column(String(255))
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
infos: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
references: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.asset_id),
foreign_keys=lambda: [AssetInfo.asset_id],
primaryjoin=lambda: Asset.id == foreign(AssetReference.asset_id),
foreign_keys=lambda: [AssetReference.asset_id],
cascade="all,delete-orphan",
passive_deletes=True,
)
preview_of: Mapped[list[AssetInfo]] = relationship(
"AssetInfo",
preview_of: Mapped[list[AssetReference]] = relationship(
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetInfo.preview_id),
foreign_keys=lambda: [AssetInfo.preview_id],
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
cache_states: Mapped[list[AssetCacheState]] = relationship(
back_populates="asset",
cascade="all, delete-orphan",
passive_deletes=True,
)
__table_args__ = (
Index("uq_assets_hash", "hash", unique=True),
Index("ix_assets_mime_type", "mime_type"),
CheckConstraint("size_bytes >= 0", name="ck_assets_size_nonneg"),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
def __repr__(self) -> str:
return f"<Asset id={self.id} hash={(self.hash or '')[:12]}>"
class AssetCacheState(Base):
__tablename__ = "asset_cache_state"
class AssetReference(Base):
"""Unified model combining file cache state and user-facing metadata.
id: Mapped[int] = mapped_column(Integer, primary_key=True, autoincrement=True)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False)
file_path: Mapped[str] = mapped_column(Text, nullable=False)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
Each row represents either:
- A filesystem reference (file_path is set) with cache state
- An API-created reference (file_path is NULL) without cache state
"""
asset: Mapped[Asset] = relationship(back_populates="cache_states")
__tablename__ = "asset_references"
__table_args__ = (
Index("ix_asset_cache_state_file_path", "file_path"),
Index("ix_asset_cache_state_asset_id", "asset_id"),
CheckConstraint("(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_acs_mtime_nonneg"),
UniqueConstraint("file_path", name="uq_asset_cache_state_file_path"),
id: Mapped[str] = mapped_column(
String(36), primary_key=True, default=lambda: str(uuid.uuid4())
)
asset_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="CASCADE"), nullable=False
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
return to_dict(self, include_none=include_none)
# Cache state fields (from former AssetCacheState)
file_path: Mapped[str | None] = mapped_column(Text, nullable=True)
mtime_ns: Mapped[int | None] = mapped_column(BigInteger, nullable=True)
needs_verify: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
is_missing: Mapped[bool] = mapped_column(Boolean, nullable=False, default=False)
enrichment_level: Mapped[int] = mapped_column(Integer, nullable=False, default=0)
def __repr__(self) -> str:
return f"<AssetCacheState id={self.id} asset_id={self.asset_id} path={self.file_path!r}>"
class AssetInfo(Base):
__tablename__ = "assets_info"
id: Mapped[str] = mapped_column(String(36), primary_key=True, default=lambda: str(uuid.uuid4()))
# Info fields (from former AssetInfo)
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False)
asset_id: Mapped[str] = mapped_column(String(36), ForeignKey("assets.id", ondelete="RESTRICT"), nullable=False)
preview_id: Mapped[str | None] = mapped_column(String(36), ForeignKey("assets.id", ondelete="SET NULL"))
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(JSON(none_as_null=True))
created_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
updated_at: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
last_access_time: Mapped[datetime] = mapped_column(DateTime(timezone=False), nullable=False, default=utcnow)
preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL")
)
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True)
)
created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
updated_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
last_access_time: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now
)
deleted_at: Mapped[datetime | None] = mapped_column(
DateTime(timezone=False), nullable=True, default=None
)
asset: Mapped[Asset] = relationship(
"Asset",
back_populates="infos",
back_populates="references",
foreign_keys=[asset_id],
lazy="selectin",
)
@ -121,51 +121,59 @@ class AssetInfo(Base):
foreign_keys=[preview_id],
)
metadata_entries: Mapped[list[AssetInfoMeta]] = relationship(
back_populates="asset_info",
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
back_populates="asset_reference",
cascade="all,delete-orphan",
passive_deletes=True,
)
tag_links: Mapped[list[AssetInfoTag]] = relationship(
back_populates="asset_info",
tag_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="asset_reference",
cascade="all,delete-orphan",
passive_deletes=True,
overlaps="tags,asset_infos",
overlaps="tags,asset_references",
)
tags: Mapped[list[Tag]] = relationship(
secondary="asset_info_tags",
back_populates="asset_infos",
secondary="asset_reference_tags",
back_populates="asset_references",
lazy="selectin",
viewonly=True,
overlaps="tag_links,asset_info_links,asset_infos,tag",
overlaps="tag_links,asset_reference_links,asset_references,tag",
)
__table_args__ = (
UniqueConstraint("asset_id", "owner_id", "name", name="uq_assets_info_asset_owner_name"),
Index("ix_assets_info_owner_name", "owner_id", "name"),
Index("ix_assets_info_owner_id", "owner_id"),
Index("ix_assets_info_asset_id", "asset_id"),
Index("ix_assets_info_name", "name"),
Index("ix_assets_info_created_at", "created_at"),
Index("ix_assets_info_last_access_time", "last_access_time"),
Index("uq_asset_references_file_path", "file_path", unique=True),
Index("ix_asset_references_asset_id", "asset_id"),
Index("ix_asset_references_owner_id", "owner_id"),
Index("ix_asset_references_name", "name"),
Index("ix_asset_references_is_missing", "is_missing"),
Index("ix_asset_references_enrichment_level", "enrichment_level"),
Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"),
Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
),
CheckConstraint(
"enrichment_level >= 0 AND enrichment_level <= 2",
name="ck_ar_enrichment_level_range",
),
)
def to_dict(self, include_none: bool = False) -> dict[str, Any]:
data = to_dict(self, include_none=include_none)
data["tags"] = [t.name for t in self.tags]
return data
def __repr__(self) -> str:
return f"<AssetInfo id={self.id} name={self.name!r} asset_id={self.asset_id}>"
path_part = f" path={self.file_path!r}" if self.file_path else ""
return f"<AssetReference id={self.id} name={self.name!r}{path_part}>"
class AssetInfoMeta(Base):
__tablename__ = "asset_info_meta"
class AssetReferenceMeta(Base):
__tablename__ = "asset_reference_meta"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
)
key: Mapped[str] = mapped_column(String(256), primary_key=True)
ordinal: Mapped[int] = mapped_column(Integer, primary_key=True, default=0)
@ -175,36 +183,40 @@ class AssetInfoMeta(Base):
val_bool: Mapped[bool | None] = mapped_column(Boolean, nullable=True)
val_json: Mapped[Any | None] = mapped_column(JSON(none_as_null=True), nullable=True)
asset_info: Mapped[AssetInfo] = relationship(back_populates="metadata_entries")
asset_reference: Mapped[AssetReference] = relationship(
back_populates="metadata_entries"
)
__table_args__ = (
Index("ix_asset_info_meta_key", "key"),
Index("ix_asset_info_meta_key_val_str", "key", "val_str"),
Index("ix_asset_info_meta_key_val_num", "key", "val_num"),
Index("ix_asset_info_meta_key_val_bool", "key", "val_bool"),
Index("ix_asset_reference_meta_key", "key"),
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
)
class AssetInfoTag(Base):
__tablename__ = "asset_info_tags"
class AssetReferenceTag(Base):
__tablename__ = "asset_reference_tags"
asset_info_id: Mapped[str] = mapped_column(
String(36), ForeignKey("assets_info.id", ondelete="CASCADE"), primary_key=True
asset_reference_id: Mapped[str] = mapped_column(
String(36),
ForeignKey("asset_references.id", ondelete="CASCADE"),
primary_key=True,
)
tag_name: Mapped[str] = mapped_column(
String(512), ForeignKey("tags.name", ondelete="RESTRICT"), primary_key=True
)
origin: Mapped[str] = mapped_column(String(32), nullable=False, default="manual")
added_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=utcnow
DateTime(timezone=False), nullable=False, default=get_utc_now
)
asset_info: Mapped[AssetInfo] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_info_links")
asset_reference: Mapped[AssetReference] = relationship(back_populates="tag_links")
tag: Mapped[Tag] = relationship(back_populates="asset_reference_links")
__table_args__ = (
Index("ix_asset_info_tags_tag_name", "tag_name"),
Index("ix_asset_info_tags_asset_info_id", "asset_info_id"),
Index("ix_asset_reference_tags_tag_name", "tag_name"),
Index("ix_asset_reference_tags_asset_reference_id", "asset_reference_id"),
)
@ -214,20 +226,18 @@ class Tag(Base):
name: Mapped[str] = mapped_column(String(512), primary_key=True)
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
asset_info_links: Mapped[list[AssetInfoTag]] = relationship(
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
back_populates="tag",
overlaps="asset_infos,tags",
overlaps="asset_references,tags",
)
asset_infos: Mapped[list[AssetInfo]] = relationship(
secondary="asset_info_tags",
asset_references: Mapped[list[AssetReference]] = relationship(
secondary="asset_reference_tags",
back_populates="tags",
viewonly=True,
overlaps="asset_info_links,tag_links,tags,asset_info",
overlaps="asset_reference_links,tag_links,tags,asset_reference",
)
__table_args__ = (
Index("ix_tags_tag_type", "tag_type"),
)
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
def __repr__(self) -> str:
return f"<Tag {self.name}>"

View File

@ -1,976 +0,0 @@
import os
import logging
import sqlalchemy as sa
from collections import defaultdict
from datetime import datetime
from typing import Iterable, Any
from sqlalchemy import select, delete, exists, func
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, contains_eager, noload
from app.assets.database.models import Asset, AssetInfo, AssetCacheState, AssetInfoMeta, AssetInfoTag, Tag
from app.assets.helpers import (
compute_relative_filename, escape_like_prefix, normalize_tags, project_kv, utcnow
)
from typing import Sequence
def visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads. Owner-less rows are visible to everyone."""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetInfo.owner_id == ""
return AssetInfo.owner_id.in_(["", owner_id])
def pick_best_live_path(states: Sequence[AssetCacheState]) -> str:
"""
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
alive = [s for s in states if getattr(s, "file_path", None) and os.path.isfile(s.file_path)]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetInfoTag.asset_info_id == AssetInfo.id)
& (AssetInfoTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_info_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetInfoMeta.asset_info_id == AssetInfo.id,
AssetInfoMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetInfoMeta.val_json.is_(None),
AssetInfoMeta.val_str.is_(None),
AssetInfoMeta.val_num.is_(None),
AssetInfoMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetInfoMeta.val_bool == bool(value))
if isinstance(value, (int, float)):
from decimal import Decimal
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetInfoMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetInfoMeta.val_str == value)
return _exists_for_pred(key, AssetInfoMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def asset_exists_by_hash(
session: Session,
*,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True)).select_from(Asset).where(Asset.hash == asset_hash).limit(1)
)
).first()
return row is not None
def asset_info_exists_for_asset_id(
session: Session,
*,
asset_id: str,
) -> bool:
q = (
select(sa.literal(True))
.select_from(AssetInfo)
.where(AssetInfo.asset_id == asset_id)
.limit(1)
)
return (session.execute(q)).first() is not None
def get_asset_by_hash(
session: Session,
*,
asset_hash: str,
) -> Asset | None:
return (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
def get_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
) -> AssetInfo | None:
return session.get(AssetInfo, asset_info_id)
def list_asset_infos_page(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> tuple[list[AssetInfo], dict[str, list[str]], int]:
base = (
select(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.options(contains_eager(AssetInfo.asset), noload(AssetInfo.tags))
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
base = base.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
base = apply_tag_filters(base, include_tags, exclude_tags)
base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower()
order = (order or "desc").lower()
sort_map = {
"name": AssetInfo.name,
"created_at": AssetInfo.created_at,
"updated_at": AssetInfo.updated_at,
"last_access_time": AssetInfo.last_access_time,
"size": Asset.size_bytes,
}
sort_col = sort_map.get(sort, AssetInfo.created_at)
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
base = base.order_by(sort_exp).limit(limit).offset(offset)
count_stmt = (
select(sa.func.count())
.select_from(AssetInfo)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(visible_owner_clause(owner_id))
)
if name_contains:
escaped, esc = escape_like_prefix(name_contains)
count_stmt = count_stmt.where(AssetInfo.name.ilike(f"%{escaped}%", escape=esc))
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int((session.execute(count_stmt)).scalar_one() or 0)
infos = (session.execute(base)).unique().scalars().all()
id_list: list[str] = [i.id for i in infos]
tag_map: dict[str, list[str]] = defaultdict(list)
if id_list:
rows = session.execute(
select(AssetInfoTag.asset_info_id, Tag.name)
.join(Tag, Tag.name == AssetInfoTag.tag_name)
.where(AssetInfoTag.asset_info_id.in_(id_list))
.order_by(AssetInfoTag.added_at)
)
for aid, tag_name in rows.all():
tag_map[aid].append(tag_name)
return infos, tag_map, total
def fetch_asset_info_asset_and_tags(
session: Session,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset, list[str]] | None:
stmt = (
select(AssetInfo, Asset, Tag.name)
.join(Asset, Asset.id == AssetInfo.asset_id)
.join(AssetInfoTag, AssetInfoTag.asset_info_id == AssetInfo.id, isouter=True)
.join(Tag, Tag.name == AssetInfoTag.tag_name, isouter=True)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.options(noload(AssetInfo.tags))
.order_by(Tag.name.asc())
)
rows = (session.execute(stmt)).all()
if not rows:
return None
first_info, first_asset, _ = rows[0]
tags: list[str] = []
seen: set[str] = set()
for _info, _asset, tag_name in rows:
if tag_name and tag_name not in seen:
seen.add(tag_name)
tags.append(tag_name)
return first_info, first_asset, tags
def fetch_asset_info_and_asset(
session: Session,
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[AssetInfo, Asset] | None:
stmt = (
select(AssetInfo, Asset)
.join(Asset, Asset.id == AssetInfo.asset_id)
.where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
.limit(1)
.options(noload(AssetInfo.tags))
)
row = session.execute(stmt)
pair = row.first()
if not pair:
return None
return pair[0], pair[1]
def list_cache_states_by_asset_id(
session: Session, *, asset_id: str
) -> Sequence[AssetCacheState]:
return (
session.execute(
select(AssetCacheState)
.where(AssetCacheState.asset_id == asset_id)
.order_by(AssetCacheState.id.asc())
)
).scalars().all()
def touch_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
ts: datetime | None = None,
only_if_newer: bool = True,
) -> None:
ts = ts or utcnow()
stmt = sa.update(AssetInfo).where(AssetInfo.id == asset_info_id)
if only_if_newer:
stmt = stmt.where(
sa.or_(AssetInfo.last_access_time.is_(None), AssetInfo.last_access_time < ts)
)
session.execute(stmt.values(last_access_time=ts))
def create_asset_info_for_existing_asset(
session: Session,
*,
asset_hash: str,
name: str,
user_metadata: dict | None = None,
tags: Sequence[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetInfo:
"""Create or return an existing AssetInfo for an Asset identified by asset_hash."""
now = utcnow()
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"Unknown asset hash {asset_hash}")
info = AssetInfo(
owner_id=owner_id,
name=name,
asset_id=asset.id,
preview_id=None,
created_at=now,
updated_at=now,
last_access_time=now,
)
try:
with session.begin_nested():
session.add(info)
session.flush()
except IntegrityError:
existing = (
session.execute(
select(AssetInfo)
.options(noload(AssetInfo.tags))
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == name,
AssetInfo.owner_id == owner_id,
)
.limit(1)
)
).unique().scalars().first()
if not existing:
raise RuntimeError("AssetInfo upsert failed to find existing row after conflict.")
return existing
# metadata["filename"] hack
new_meta = dict(user_metadata or {})
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=info.id,
user_metadata=new_meta,
)
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=info.id,
tags=tags,
origin=tag_origin,
)
return info
def set_asset_info_tags(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> dict:
desired = normalize_tags(tags)
current = set(
tag_name for (tag_name,) in (
session.execute(select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id))
).all()
)
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all([
AssetInfoTag(asset_info_id=asset_info_id, tag_name=t, origin=origin, added_at=utcnow())
for t in to_add
])
session.flush()
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(AssetInfoTag.asset_info_id == asset_info_id, AssetInfoTag.tag_name.in_(to_remove))
)
session.flush()
return {"added": to_add, "removed": to_remove, "total": desired}
def replace_asset_info_metadata_projection(
session: Session,
*,
asset_info_id: str,
user_metadata: dict | None = None,
) -> None:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info.user_metadata = user_metadata or {}
info.updated_at = utcnow()
session.flush()
session.execute(delete(AssetInfoMeta).where(AssetInfoMeta.asset_info_id == asset_info_id))
session.flush()
if not user_metadata:
return
rows: list[AssetInfoMeta] = []
for k, v in user_metadata.items():
for r in project_kv(k, v):
rows.append(
AssetInfoMeta(
asset_info_id=asset_info_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def ingest_fs_asset(
session: Session,
*,
asset_hash: str,
abs_path: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: dict | None = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> dict:
"""
Idempotently upsert:
- Asset by content hash (create if missing)
- AssetCacheState(file_path) pointing to asset_id
- Optionally AssetInfo + tag links and metadata projection
Returns flags and ids.
"""
locator = os.path.abspath(abs_path)
now = utcnow()
if preview_id:
if not session.get(Asset, preview_id):
preview_id = None
out: dict[str, Any] = {
"asset_created": False,
"asset_updated": False,
"state_created": False,
"state_updated": False,
"asset_info_id": None,
}
# 1) Asset by hash
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
).scalars().first()
if not asset:
vals = {
"hash": asset_hash,
"size_bytes": int(size_bytes),
"mime_type": mime_type,
"created_at": now,
}
res = session.execute(
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
if int(res.rowcount or 0) > 0:
out["asset_created"] = True
asset = (
session.execute(
select(Asset).where(Asset.hash == asset_hash).limit(1)
)
).scalars().first()
if not asset:
raise RuntimeError("Asset row not found after upsert.")
else:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
out["asset_updated"] = True
# 2) AssetCacheState upsert by file_path (unique)
vals = {
"asset_id": asset.id,
"file_path": locator,
"mtime_ns": int(mtime_ns),
}
ins = (
sqlite.insert(AssetCacheState)
.values(**vals)
.on_conflict_do_nothing(index_elements=[AssetCacheState.file_path])
)
res = session.execute(ins)
if int(res.rowcount or 0) > 0:
out["state_created"] = True
else:
upd = (
sa.update(AssetCacheState)
.where(AssetCacheState.file_path == locator)
.where(
sa.or_(
AssetCacheState.asset_id != asset.id,
AssetCacheState.mtime_ns.is_(None),
AssetCacheState.mtime_ns != int(mtime_ns),
)
)
.values(asset_id=asset.id, mtime_ns=int(mtime_ns))
)
res2 = session.execute(upd)
if int(res2.rowcount or 0) > 0:
out["state_updated"] = True
# 3) Optional AssetInfo + tags + metadata
if info_name:
try:
with session.begin_nested():
info = AssetInfo(
owner_id=owner_id,
name=info_name,
asset_id=asset.id,
preview_id=preview_id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(info)
session.flush()
out["asset_info_id"] = info.id
except IntegrityError:
pass
existing_info = (
session.execute(
select(AssetInfo)
.where(
AssetInfo.asset_id == asset.id,
AssetInfo.name == info_name,
(AssetInfo.owner_id == owner_id),
)
.limit(1)
)
).unique().scalar_one_or_none()
if not existing_info:
raise RuntimeError("Failed to update or insert AssetInfo.")
if preview_id and existing_info.preview_id != preview_id:
existing_info.preview_id = preview_id
existing_info.updated_at = now
if existing_info.last_access_time < now:
existing_info.last_access_time = now
session.flush()
out["asset_info_id"] = existing_info.id
norm = [t.strip().lower() for t in (tags or []) if (t or "").strip()]
if norm and out["asset_info_id"] is not None:
if not require_existing_tags:
ensure_tags_exist(session, norm, tag_type="user")
existing_tag_names = set(
name for (name,) in (session.execute(select(Tag.name).where(Tag.name.in_(norm)))).all()
)
missing = [t for t in norm if t not in existing_tag_names]
if missing and require_existing_tags:
raise ValueError(f"Unknown tags: {missing}")
existing_links = set(
tag_name
for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == out["asset_info_id"])
)
).all()
)
to_add = [t for t in norm if t in existing_tag_names and t not in existing_links]
if to_add:
session.add_all(
[
AssetInfoTag(
asset_info_id=out["asset_info_id"],
tag_name=t,
origin=tag_origin,
added_at=now,
)
for t in to_add
]
)
session.flush()
# metadata["filename"] hack
if out["asset_info_id"] is not None:
primary_path = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=asset.id))
computed_filename = compute_relative_filename(primary_path) if primary_path else None
current_meta = existing_info.user_metadata or {}
new_meta = dict(current_meta)
if user_metadata is not None:
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
replace_asset_info_metadata_projection(
session,
asset_info_id=out["asset_info_id"],
user_metadata=new_meta,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
return out
def update_asset_info_full(
session: Session,
*,
asset_info_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: dict | None = None,
tag_origin: str = "manual",
asset_info_row: Any = None,
) -> AssetInfo:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
else:
info = asset_info_row
touched = False
if name is not None and name != info.name:
info.name = name
touched = True
computed_filename = None
try:
p = pick_best_live_path(list_cache_states_by_asset_id(session, asset_id=info.asset_id))
if p:
computed_filename = compute_relative_filename(p)
except Exception:
computed_filename = None
if user_metadata is not None:
new_meta = dict(user_metadata)
if computed_filename:
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
else:
if computed_filename:
current_meta = info.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
new_meta["filename"] = computed_filename
replace_asset_info_metadata_projection(
session, asset_info_id=asset_info_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_asset_info_tags(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
info.updated_at = utcnow()
session.flush()
return info
def delete_asset_info_by_id(
session: Session,
*,
asset_info_id: str,
owner_id: str,
) -> bool:
stmt = sa.delete(AssetInfo).where(
AssetInfo.id == asset_info_id,
visible_owner_clause(owner_id),
)
return int((session.execute(stmt)).rowcount or 0) > 0
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetInfoTag.tag_name.label("tag_name"),
func.count(AssetInfoTag.asset_info_id).label("cnt"),
)
.select_from(AssetInfoTag)
.join(AssetInfo, AssetInfo.id == AssetInfoTag.asset_info_id)
.where(visible_owner_clause(owner_id))
.group_by(AssetInfoTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_like_prefix(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
total_q = total_q.where(
Tag.name.in_(select(AssetInfoTag.tag_name).group_by(AssetInfoTag.tag_name))
)
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_asset_tags(session: Session, *, asset_info_id: str) -> list[str]:
return [
tag_name for (tag_name,) in (
session.execute(
select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
]
def add_tags_to_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
asset_info_row: Any = None,
) -> dict:
if not asset_info_row:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"added": [], "already_present": [], "total_tags": total}
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetInfoTag(
asset_info_id=asset_info_id,
tag_name=t,
origin=origin,
added_at=utcnow(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_asset_tags(session, asset_info_id=asset_info_id))
return {
"added": sorted(((after - current) & want)),
"already_present": sorted(want & current),
"total_tags": sorted(after),
}
def remove_tags_from_asset_info(
session: Session,
*,
asset_info_id: str,
tags: Sequence[str],
) -> dict:
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": [], "not_present": [], "total_tags": total}
existing = {
tag_name
for (tag_name,) in (
session.execute(
sa.select(AssetInfoTag.tag_name).where(AssetInfoTag.asset_info_id == asset_info_id)
)
).all()
}
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetInfoTag)
.where(
AssetInfoTag.asset_info_id == asset_info_id,
AssetInfoTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_asset_tags(session, asset_info_id=asset_info_id)
return {"removed": to_remove, "not_present": not_present, "total_tags": total}
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sa.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)
def set_asset_info_preview(
session: Session,
*,
asset_info_id: str,
preview_asset_id: str | None = None,
) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
info = session.get(AssetInfo, asset_info_id)
if not info:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if preview_asset_id is None:
info.preview_id = None
else:
# validate preview asset exists
if not session.get(Asset, preview_asset_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found")
info.preview_id = preview_asset_id
info.updated_at = utcnow()
session.flush()

View File

@ -0,0 +1,121 @@
from app.assets.database.queries.asset import (
asset_exists_by_hash,
bulk_insert_assets,
get_asset_by_hash,
get_existing_asset_ids,
reassign_asset_references,
update_asset_hash_and_mime,
upsert_asset,
)
from app.assets.database.queries.asset_reference import (
CacheStateRow,
UnenrichedReferenceRow,
bulk_insert_references_ignore_conflicts,
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
convert_metadata_to_rows,
delete_assets_by_ids,
delete_orphaned_seed_asset,
delete_reference_by_id,
delete_references_by_ids,
fetch_reference_and_asset,
fetch_reference_asset_and_tags,
get_or_create_reference,
get_reference_by_file_path,
get_reference_by_id,
get_reference_with_owner_check,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_references_for_prefixes,
get_unenriched_references,
get_unreferenced_unhashed_asset_ids,
insert_reference,
list_references_by_asset_id,
list_references_page,
mark_references_missing_outside_prefixes,
reference_exists_for_asset_id,
restore_references_by_paths,
set_reference_metadata,
set_reference_preview,
soft_delete_reference_by_id,
update_reference_access_time,
update_reference_name,
update_reference_timestamps,
update_reference_updated_at,
upsert_reference,
)
from app.assets.database.queries.tags import (
AddTagsResult,
RemoveTagsResult,
SetTagsResult,
add_missing_tag_for_asset_id,
add_tags_to_reference,
bulk_insert_tags_and_meta,
ensure_tags_exist,
get_reference_tags,
list_tags_with_usage,
remove_missing_tag_for_asset_id,
remove_tags_from_reference,
set_reference_tags,
validate_tags_exist,
)
__all__ = [
"AddTagsResult",
"CacheStateRow",
"RemoveTagsResult",
"SetTagsResult",
"UnenrichedReferenceRow",
"add_missing_tag_for_asset_id",
"add_tags_to_reference",
"asset_exists_by_hash",
"bulk_insert_assets",
"bulk_insert_references_ignore_conflicts",
"bulk_insert_tags_and_meta",
"bulk_update_enrichment_level",
"bulk_update_is_missing",
"bulk_update_needs_verify",
"convert_metadata_to_rows",
"delete_assets_by_ids",
"delete_orphaned_seed_asset",
"delete_reference_by_id",
"delete_references_by_ids",
"ensure_tags_exist",
"fetch_reference_and_asset",
"fetch_reference_asset_and_tags",
"get_asset_by_hash",
"get_existing_asset_ids",
"get_or_create_reference",
"get_reference_by_file_path",
"get_reference_by_id",
"get_reference_with_owner_check",
"get_reference_ids_by_ids",
"get_reference_tags",
"get_references_by_paths_and_asset_ids",
"get_references_for_prefixes",
"get_unenriched_references",
"get_unreferenced_unhashed_asset_ids",
"insert_reference",
"list_references_by_asset_id",
"list_references_page",
"list_tags_with_usage",
"mark_references_missing_outside_prefixes",
"reassign_asset_references",
"reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id",
"remove_tags_from_reference",
"restore_references_by_paths",
"set_reference_metadata",
"set_reference_preview",
"soft_delete_reference_by_id",
"set_reference_tags",
"update_asset_hash_and_mime",
"update_reference_access_time",
"update_reference_name",
"update_reference_timestamps",
"update_reference_updated_at",
"upsert_asset",
"upsert_reference",
"validate_tags_exist",
]

View File

@ -0,0 +1,140 @@
import sqlalchemy as sa
from sqlalchemy import select
from sqlalchemy.dialects import sqlite
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries.common import MAX_BIND_PARAMS, calculate_rows_per_statement, iter_chunks
def asset_exists_by_hash(
session: Session,
asset_hash: str,
) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
row = (
session.execute(
select(sa.literal(True))
.select_from(Asset)
.where(Asset.hash == asset_hash)
.limit(1)
)
).first()
return row is not None
def get_asset_by_hash(
session: Session,
asset_hash: str,
) -> Asset | None:
return (
(session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1)))
.scalars()
.first()
)
def upsert_asset(
session: Session,
asset_hash: str,
size_bytes: int,
mime_type: str | None = None,
) -> tuple[Asset, bool, bool]:
"""Upsert an Asset by hash. Returns (asset, created, updated)."""
vals = {"hash": asset_hash, "size_bytes": int(size_bytes)}
if mime_type:
vals["mime_type"] = mime_type
ins = (
sqlite.insert(Asset)
.values(**vals)
.on_conflict_do_nothing(index_elements=[Asset.hash])
)
res = session.execute(ins)
created = int(res.rowcount or 0) > 0
asset = (
session.execute(select(Asset).where(Asset.hash == asset_hash).limit(1))
.scalars()
.first()
)
if not asset:
raise RuntimeError("Asset row not found after upsert.")
updated = False
if not created:
changed = False
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes)
changed = True
if mime_type and asset.mime_type != mime_type:
asset.mime_type = mime_type
changed = True
if changed:
updated = True
return asset, created, updated
def bulk_insert_assets(
session: Session,
rows: list[dict],
) -> None:
"""Bulk insert Asset rows with ON CONFLICT DO NOTHING on hash."""
if not rows:
return
ins = sqlite.insert(Asset).on_conflict_do_nothing(index_elements=[Asset.hash])
for chunk in iter_chunks(rows, calculate_rows_per_statement(5)):
session.execute(ins, chunk)
def get_existing_asset_ids(
session: Session,
asset_ids: list[str],
) -> set[str]:
"""Return the subset of asset_ids that exist in the database."""
if not asset_ids:
return set()
found: set[str] = set()
for chunk in iter_chunks(asset_ids, MAX_BIND_PARAMS):
rows = session.execute(
select(Asset.id).where(Asset.id.in_(chunk))
).fetchall()
found.update(row[0] for row in rows)
return found
def update_asset_hash_and_mime(
session: Session,
asset_id: str,
asset_hash: str | None = None,
mime_type: str | None = None,
) -> bool:
"""Update asset hash and/or mime_type. Returns True if asset was found."""
asset = session.get(Asset, asset_id)
if not asset:
return False
if asset_hash is not None:
asset.hash = asset_hash
if mime_type is not None:
asset.mime_type = mime_type
return True
def reassign_asset_references(
session: Session,
from_asset_id: str,
to_asset_id: str,
reference_id: str,
) -> None:
"""Reassign a reference from one asset to another.
Used when merging a stub asset into an existing asset with the same hash.
"""
ref = session.get(AssetReference, reference_id)
if ref and ref.asset_id == from_asset_id:
ref.asset_id = to_asset_id
session.flush()

File diff suppressed because it is too large Load Diff

View File

@ -0,0 +1,54 @@
"""Shared utilities for database query modules."""
import os
from typing import Iterable
import sqlalchemy as sa
from app.assets.database.models import AssetReference
from app.assets.helpers import escape_sql_like_string
MAX_BIND_PARAMS = 800
def calculate_rows_per_statement(cols: int) -> int:
"""Calculate how many rows can fit in one statement given column count."""
return max(1, MAX_BIND_PARAMS // max(1, cols))
def iter_chunks(seq, n: int):
"""Yield successive n-sized chunks from seq."""
for i in range(0, len(seq), n):
yield seq[i : i + n]
def iter_row_chunks(rows: list[dict], cols_per_row: int) -> Iterable[list[dict]]:
"""Yield chunks of rows sized to fit within bind param limits."""
if not rows:
return
yield from iter_chunks(rows, calculate_rows_per_statement(cols_per_row))
def build_visible_owner_clause(owner_id: str) -> sa.sql.ClauseElement:
"""Build owner visibility predicate for reads.
Owner-less rows are visible to everyone.
"""
owner_id = (owner_id or "").strip()
if owner_id == "":
return AssetReference.owner_id == ""
return AssetReference.owner_id.in_(["", owner_id])
def build_prefix_like_conditions(
prefixes: list[str],
) -> list[sa.sql.ColumnElement]:
"""Build LIKE conditions for matching file paths under directory prefixes."""
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds

View File

@ -0,0 +1,356 @@
from dataclasses import dataclass
from typing import Iterable, Sequence
import sqlalchemy as sa
from sqlalchemy import delete, func, select
from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from app.assets.database.models import (
AssetReference,
AssetReferenceMeta,
AssetReferenceTag,
Tag,
)
from app.assets.database.queries.common import (
build_visible_owner_clause,
iter_row_chunks,
)
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
@dataclass(frozen=True)
class AddTagsResult:
added: list[str]
already_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class RemoveTagsResult:
removed: list[str]
not_present: list[str]
total_tags: list[str]
@dataclass(frozen=True)
class SetTagsResult:
added: list[str]
removed: list[str]
total: list[str]
def validate_tags_exist(session: Session, tags: list[str]) -> None:
"""Raise ValueError if any of the given tag names do not exist."""
existing_tag_names = set(
name
for (name,) in session.execute(select(Tag.name).where(Tag.name.in_(tags))).all()
)
missing = [t for t in tags if t not in existing_tag_names]
if missing:
raise ValueError(f"Unknown tags: {missing}")
def ensure_tags_exist(
session: Session, names: Iterable[str], tag_type: str = "user"
) -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
session.execute(ins)
def get_reference_tags(session: Session, reference_id: str) -> list[str]:
return [
tag_name
for (tag_name,) in (
session.execute(
select(AssetReferenceTag.tag_name).where(
AssetReferenceTag.asset_reference_id == reference_id
)
)
).all()
]
def set_reference_tags(
session: Session,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
) -> SetTagsResult:
desired = normalize_tags(tags)
current = set(get_reference_tags(session, reference_id))
to_add = [t for t in desired if t not in current]
to_remove = [t for t in current if t not in desired]
if to_add:
ensure_tags_exist(session, to_add, tag_type="user")
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
if to_remove:
session.execute(
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
def add_tags_to_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
origin: str = "manual",
create_if_missing: bool = True,
reference_row: AssetReference | None = None,
) -> AddTagsResult:
if not reference_row:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return AddTagsResult(added=[], already_present=[], total_tags=total)
if create_if_missing:
ensure_tags_exist(session, norm, tag_type="user")
current = set(get_reference_tags(session, reference_id))
want = set(norm)
to_add = sorted(want - current)
if to_add:
with session.begin_nested() as nested:
try:
session.add_all(
[
AssetReferenceTag(
asset_reference_id=reference_id,
tag_name=t,
origin=origin,
added_at=get_utc_now(),
)
for t in to_add
]
)
session.flush()
except IntegrityError:
nested.rollback()
after = set(get_reference_tags(session, reference_id=reference_id))
return AddTagsResult(
added=sorted(((after - current) & want)),
already_present=sorted(want & current),
total_tags=sorted(after),
)
def remove_tags_from_reference(
session: Session,
reference_id: str,
tags: Sequence[str],
) -> RemoveTagsResult:
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
norm = normalize_tags(tags)
if not norm:
total = get_reference_tags(session, reference_id=reference_id)
return RemoveTagsResult(removed=[], not_present=[], total_tags=total)
existing = set(get_reference_tags(session, reference_id))
to_remove = sorted(set(t for t in norm if t in existing))
not_present = sorted(set(t for t in norm if t not in existing))
if to_remove:
session.execute(
delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id == reference_id,
AssetReferenceTag.tag_name.in_(to_remove),
)
)
session.flush()
total = get_reference_tags(session, reference_id=reference_id)
return RemoveTagsResult(removed=to_remove, not_present=not_present, total_tags=total)
def add_missing_tag_for_asset_id(
session: Session,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sa.select(
AssetReference.id.label("asset_reference_id"),
sa.literal("missing").label("tag_name"),
sa.literal(origin).label("origin"),
sa.literal(get_utc_now()).label("added_at"),
)
.where(AssetReference.asset_id == asset_id)
.where(
sa.not_(
sa.exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == "missing")
)
)
)
)
session.execute(
sqlite.insert(AssetReferenceTag)
.from_select(
["asset_reference_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
)
def remove_missing_tag_for_asset_id(
session: Session,
asset_id: str,
) -> None:
session.execute(
sa.delete(AssetReferenceTag).where(
AssetReferenceTag.asset_reference_id.in_(
sa.select(AssetReference.id).where(AssetReference.asset_id == asset_id)
),
AssetReferenceTag.tag_name == "missing",
)
)
def list_tags_with_usage(
session: Session,
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
include_zero: bool = True,
order: str = "count_desc",
owner_id: str = "",
) -> tuple[list[tuple[str, str, int]], int]:
counts_sq = (
select(
AssetReferenceTag.tag_name.label("tag_name"),
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
.subquery()
)
q = (
select(
Tag.name,
Tag.tag_type,
func.coalesce(counts_sq.c.cnt, 0).label("count"),
)
.select_from(Tag)
.join(counts_sq, counts_sq.c.tag_name == Tag.name, isouter=True)
)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
q = q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
q = q.where(func.coalesce(counts_sq.c.cnt, 0) > 0)
if order == "name_asc":
q = q.order_by(Tag.name.asc())
else:
q = q.order_by(func.coalesce(counts_sq.c.cnt, 0).desc(), Tag.name.asc())
total_q = select(func.count()).select_from(Tag)
if prefix:
escaped, esc = escape_sql_like_string(prefix.strip().lower())
total_q = total_q.where(Tag.name.like(escaped + "%", escape=esc))
if not include_zero:
visible_tags_sq = (
select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name)
)
total_q = total_q.where(Tag.name.in_(visible_tags_sq))
rows = (session.execute(q.limit(limit).offset(offset))).all()
total = (session.execute(total_q)).scalar_one()
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
return rows_norm, int(total or 0)
def bulk_insert_tags_and_meta(
session: Session,
tag_rows: list[dict],
meta_rows: list[dict],
) -> None:
"""Batch insert into asset_reference_tags and asset_reference_meta.
Uses ON CONFLICT DO NOTHING.
Args:
session: Database session
tag_rows: Dicts with: asset_reference_id, tag_name, origin, added_at
meta_rows: Dicts with: asset_reference_id, key, ordinal, val_*
"""
if tag_rows:
ins_tags = sqlite.insert(AssetReferenceTag).on_conflict_do_nothing(
index_elements=[
AssetReferenceTag.asset_reference_id,
AssetReferenceTag.tag_name,
]
)
for chunk in iter_row_chunks(tag_rows, cols_per_row=4):
session.execute(ins_tags, chunk)
if meta_rows:
ins_meta = sqlite.insert(AssetReferenceMeta).on_conflict_do_nothing(
index_elements=[
AssetReferenceMeta.asset_reference_id,
AssetReferenceMeta.key,
AssetReferenceMeta.ordinal,
]
)
for chunk in iter_row_chunks(meta_rows, cols_per_row=7):
session.execute(ins_meta, chunk)

View File

@ -1,62 +0,0 @@
from typing import Iterable
import sqlalchemy
from sqlalchemy.orm import Session
from sqlalchemy.dialects import sqlite
from app.assets.helpers import normalize_tags, utcnow
from app.assets.database.models import Tag, AssetInfoTag, AssetInfo
def ensure_tags_exist(session: Session, names: Iterable[str], tag_type: str = "user") -> None:
wanted = normalize_tags(list(names))
if not wanted:
return
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
ins = (
sqlite.insert(Tag)
.values(rows)
.on_conflict_do_nothing(index_elements=[Tag.name])
)
return session.execute(ins)
def add_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
origin: str = "automatic",
) -> None:
select_rows = (
sqlalchemy.select(
AssetInfo.id.label("asset_info_id"),
sqlalchemy.literal("missing").label("tag_name"),
sqlalchemy.literal(origin).label("origin"),
sqlalchemy.literal(utcnow()).label("added_at"),
)
.where(AssetInfo.asset_id == asset_id)
.where(
sqlalchemy.not_(
sqlalchemy.exists().where((AssetInfoTag.asset_info_id == AssetInfo.id) & (AssetInfoTag.tag_name == "missing"))
)
)
)
session.execute(
sqlite.insert(AssetInfoTag)
.from_select(
["asset_info_id", "tag_name", "origin", "added_at"],
select_rows,
)
.on_conflict_do_nothing(index_elements=[AssetInfoTag.asset_info_id, AssetInfoTag.tag_name])
)
def remove_missing_tag_for_asset_id(
session: Session,
*,
asset_id: str,
) -> None:
session.execute(
sqlalchemy.delete(AssetInfoTag).where(
AssetInfoTag.asset_info_id.in_(sqlalchemy.select(AssetInfo.id).where(AssetInfo.asset_id == asset_id)),
AssetInfoTag.tag_name == "missing",
)
)

View File

@ -1,75 +0,0 @@
from blake3 import blake3
from typing import IO
import os
import asyncio
DEFAULT_CHUNK = 8 * 1024 *1024 # 8MB
# NOTE: this allows hashing different representations of a file-like object
def blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""
Returns a BLAKE3 hex digest for ``fp``, which may be:
- a filename (str/bytes) or PathLike
- an open binary file object
If ``fp`` is a file object, it must be opened in **binary** mode and support
``read``, ``seek``, and ``tell``. The function will seek to the start before
reading and will attempt to restore the original position afterward.
"""
# duck typing to check if input is a file-like object
if hasattr(fp, "read"):
return _hash_file_obj(fp, chunk_size)
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
async def blake3_hash_async(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
) -> str:
"""Async wrapper for ``blake3_hash_sync``.
Uses a worker thread so the event loop remains responsive.
"""
# If it is a path, open inside the worker thread to keep I/O off the loop.
if hasattr(fp, "read"):
return await asyncio.to_thread(blake3_hash, fp, chunk_size)
def _worker() -> str:
with open(os.fspath(fp), "rb") as f:
return _hash_file_obj(f, chunk_size)
return await asyncio.to_thread(_worker)
def _hash_file_obj(file_obj: IO, chunk_size: int = DEFAULT_CHUNK) -> str:
"""
Hash an already-open binary file object by streaming in chunks.
- Seeks to the beginning before reading (if supported).
- Restores the original position afterward (if tell/seek are supported).
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
# in case file object is already open and not at the beginning, track so can be restored after hashing
orig_pos = file_obj.tell()
try:
# seek to the beginning before reading
if orig_pos != 0:
file_obj.seek(0)
h = blake3()
while True:
chunk = file_obj.read(chunk_size)
if not chunk:
break
h.update(chunk)
return h.hexdigest()
finally:
# restore original position in file object, if needed
if orig_pos != 0:
file_obj.seek(orig_pos)

View File

@ -1,226 +1,42 @@
import contextlib
import os
from decimal import Decimal
from aiohttp import web
from datetime import datetime, timezone
from pathlib import Path
from typing import Literal, Any
import folder_paths
from typing import Sequence
RootType = Literal["models", "input", "output"]
ALLOWED_ROOTS: tuple[RootType, ...] = ("models", "input", "output")
def get_query_dict(request: web.Request) -> dict[str, Any]:
def select_best_live_path(states: Sequence) -> str:
"""
Gets a dictionary of query parameters from the request.
'request.query' is a MultiMapping[str], needs to be converted to a dictionary to be validated by Pydantic.
Return the best on-disk path among cache states:
1) Prefer a path that exists with needs_verify == False (already verified).
2) Otherwise, pick the first path that exists.
3) Otherwise return empty string.
"""
query_dict = {
key: request.query.getall(key) if len(request.query.getall(key)) > 1 else request.query.get(key)
for key in request.query.keys()
}
return query_dict
alive = [
s
for s in states
if getattr(s, "file_path", None) and os.path.isfile(s.file_path)
]
if not alive:
return ""
for s in alive:
if not getattr(s, "needs_verify", False):
return s.file_path
return alive[0].file_path
def list_tree(base_dir: str) -> list[str]:
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
for dirpath, _subdirs, filenames in os.walk(base_abs, topdown=True, followlinks=False):
for name in filenames:
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out
def prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def escape_sql_like_string(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char in a LIKE prefix.
def escape_like_prefix(s: str, escape: str = "!") -> tuple[str, str]:
"""Escapes %, _ and the escape char itself in a LIKE prefix.
Returns (escaped_prefix, escape_char). Caller should append '%' and pass escape=escape_char to .like().
Returns (escaped_prefix, escape_char).
"""
s = s.replace(escape, escape + escape) # escape the escape char first
s = s.replace("%", escape + "%").replace("_", escape + "_") # escape LIKE wildcards
return s, escape
def fast_asset_file_check(
*,
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
if mtime_db is None:
return False
actual_mtime_ns = getattr(stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000))
if int(mtime_db) != int(actual_mtime_ns):
return False
sz = int(size_db or 0)
if sz > 0:
return int(stat_result.st_size) == sz
return True
def utcnow() -> datetime:
def get_utc_now() -> datetime:
"""Naive UTC timestamp (no tzinfo). We always treat DB datetimes as UTC."""
return datetime.now(timezone.utc).replace(tzinfo=None)
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build a list of (folder_name, base_paths[]) categories that are configured for model locations.
We trust `folder_paths.folder_names_and_paths` and include a category if
*any* of its base paths lies under the Comfy `models_dir`.
"""
targets: list[tuple[str, list[str]]] = []
models_root = os.path.abspath(folder_paths.models_dir)
for name, values in folder_paths.folder_names_and_paths.items():
paths, _exts = values[0], values[1] # NOTE: this prevents nodepacks that hackily edit folder_... from breaking ComfyUI
if any(os.path.abspath(p).startswith(models_root + os.sep) for p in paths):
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
root = tags[0]
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
else:
base_dir = os.path.abspath(
folder_paths.get_input_directory() if root == "input" else folder_paths.get_output_directory()
)
raw_subdirs = tags[1:]
for i in raw_subdirs:
if i in (".", ".."):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def ensure_within_base(candidate: str, base: str) -> None:
cand_abs = os.path.abspath(candidate)
base_abs = os.path.abspath(base)
try:
if os.path.commonpath([cand_abs, base_abs]) != base_abs:
raise ValueError("destination escapes base directory")
except Exception:
raise ValueError("invalid destination path")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
NOTE: this is a temporary helper, used only for initializing metadata["filename"] field.
"""
try:
root_category, rel_path = get_relative_to_root_category_path_of_asset(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_relative_to_root_category_path_of_asset(file_path: str) -> tuple[Literal["input", "output", "models"], str]:
"""Given an absolute or relative file path, determine which root category the path belongs to:
- 'input' if the file resides under `folder_paths.get_input_directory()`
- 'output' if the file resides under `folder_paths.get_output_directory()`
- 'models' if the file resides under any base path of categories returned by `get_comfy_models_folders()`
Returns:
(root_category, relative_path_inside_that_root)
For 'models', the relative path is prefixed with the category name:
e.g. ('models', 'vae/test/sub/ae.safetensors')
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
fp_abs = os.path.abspath(file_path)
def _is_within(child: str, parent: str) -> bool:
try:
return os.path.commonpath([child, parent]) == parent
except Exception:
return False
def _rel(child: str, parent: str) -> str:
return os.path.relpath(os.path.join(os.sep, os.path.relpath(child, parent)), os.sep)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _is_within(fp_abs, input_base):
return "input", _rel(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _is_within(fp_abs, output_base):
return "output", _rel(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _rel(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(f"Path is not within input, output, or configured model bases: {file_path}")
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return a tuple (name, tags) derived from a filesystem path.
Semantics:
- Root category is determined by `get_relative_to_root_category_path_of_asset`.
- The returned `name` is the base filename with extension from the relative path.
- The returned `tags` are:
[root_category] + parent folders of the relative path (in order)
For 'models', this means:
file '/.../ModelsDir/vae/test_tag/ae.safetensors'
-> root_category='models', some_path='vae/test_tag/ae.safetensors'
-> name='ae.safetensors', tags=['models', 'vae', 'test_tag']
Raises:
ValueError: if the path does not belong to input, output, or configured model bases.
"""
root_category, some_path = get_relative_to_root_category_path_of_asset(file_path)
p = Path(some_path)
parent_parts = [part for part in p.parent.parts if part not in (".", "..", p.anchor)]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))
def normalize_tags(tags: list[str] | None) -> list[str]:
"""
@ -228,85 +44,22 @@ def normalize_tags(tags: list[str] | None) -> list[str]:
- Stripping whitespace and converting to lowercase.
- Removing duplicates.
"""
return [t.strip().lower() for t in (tags or []) if (t or "").strip()]
return list(dict.fromkeys(t.strip().lower() for t in (tags or []) if (t or "").strip()))
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
abs_path = os.path.abspath(abs_path)
allowed = False
for b in bases:
base_abs = os.path.abspath(b)
with contextlib.suppress(Exception):
if os.path.commonpath([abs_path, base_abs]) == base_abs:
allowed = True
break
if allowed:
out.append(abs_path)
return out
def is_scalar(v):
if v is None:
return True
if isinstance(v, bool):
return True
if isinstance(v, (int, float, Decimal, str)):
return True
return False
def validate_blake3_hash(s: str) -> str:
"""Validate and normalize a blake3 hash string.
def project_kv(key: str, value):
Returns canonical 'blake3:<hex>' or raises ValueError.
"""
Turn a metadata key/value into typed projection rows.
Returns list[dict] with keys:
key, ordinal, and one of val_str / val_num / val_bool / val_json (others None)
"""
rows: list[dict] = []
def _null_row(ordinal: int) -> dict:
return {
"key": key, "ordinal": ordinal,
"val_str": None, "val_num": None, "val_bool": None, "val_json": None
}
if value is None:
rows.append(_null_row(0))
return rows
if is_scalar(value):
if isinstance(value, bool):
rows.append({"key": key, "ordinal": 0, "val_bool": bool(value)})
elif isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
rows.append({"key": key, "ordinal": 0, "val_num": num})
elif isinstance(value, str):
rows.append({"key": key, "ordinal": 0, "val_str": value})
else:
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
if isinstance(value, list):
if all(is_scalar(x) for x in value):
for i, x in enumerate(value):
if x is None:
rows.append(_null_row(i))
elif isinstance(x, bool):
rows.append({"key": key, "ordinal": i, "val_bool": bool(x)})
elif isinstance(x, (int, float, Decimal)):
num = x if isinstance(x, Decimal) else Decimal(str(x))
rows.append({"key": key, "ordinal": i, "val_num": num})
elif isinstance(x, str):
rows.append({"key": key, "ordinal": i, "val_str": x})
else:
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
for i, x in enumerate(value):
rows.append({"key": key, "ordinal": i, "val_json": x})
return rows
rows.append({"key": key, "ordinal": 0, "val_json": value})
return rows
s = s.strip().lower()
if not s or ":" not in s:
raise ValueError("hash must be 'blake3:<hex>'")
algo, digest = s.split(":", 1)
if (
algo != "blake3"
or len(digest) != 64
or any(c for c in digest if c not in "0123456789abcdef")
):
raise ValueError("hash must be 'blake3:<hex>'")
return f"{algo}:{digest}"

View File

@ -1,516 +0,0 @@
import os
import mimetypes
import contextlib
from typing import Sequence
from app.database.db import create_session
from app.assets.api import schemas_out, schemas_in
from app.assets.database.queries import (
asset_exists_by_hash,
asset_info_exists_for_asset_id,
get_asset_by_hash,
get_asset_info_by_id,
fetch_asset_info_asset_and_tags,
fetch_asset_info_and_asset,
create_asset_info_for_existing_asset,
touch_asset_info_by_id,
update_asset_info_full,
delete_asset_info_by_id,
list_cache_states_by_asset_id,
list_asset_infos_page,
list_tags_with_usage,
get_asset_tags,
add_tags_to_asset_info,
remove_tags_from_asset_info,
pick_best_live_path,
ingest_fs_asset,
set_asset_info_preview,
)
from app.assets.helpers import resolve_destination_from_tags, ensure_within_base
from app.assets.database.models import Asset
def _safe_sort_field(requested: str | None) -> str:
if not requested:
return "created_at"
v = requested.lower()
if v in {"name", "created_at", "updated_at", "size", "last_access_time"}:
return v
return "created_at"
def _get_size_mtime_ns(path: str) -> tuple[int, int]:
st = os.stat(path, follow_symlinks=True)
return st.st_size, getattr(st, "st_mtime_ns", int(st.st_mtime * 1_000_000_000))
def _safe_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
if n:
return n
return fallback
def asset_exists(*, asset_hash: str) -> bool:
"""
Check if an asset with a given hash exists in database.
"""
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def list_assets(
*,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
owner_id: str = "",
) -> schemas_out.AssetsList:
sort = _safe_sort_field(sort)
order = "desc" if (order or "desc").lower() not in {"asc", "desc"} else order.lower()
with create_session() as session:
infos, tag_map, total = list_asset_infos_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
summaries: list[schemas_out.AssetSummary] = []
for info in infos:
asset = info.asset
tags = tag_map.get(info.id, [])
summaries.append(
schemas_out.AssetSummary(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
created_at=info.created_at,
updated_at=info.updated_at,
last_access_time=info.last_access_time,
)
)
return schemas_out.AssetsList(
assets=summaries,
total=total,
has_more=(offset + len(summaries)) < total,
)
def get_asset(
*,
asset_info_id: str,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset, tag_names = res
preview_id = info.preview_id
return schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
def resolve_asset_content_for_download(
*,
asset_info_id: str,
owner_id: str = "",
) -> tuple[str, str, str]:
with create_session() as session:
pair = fetch_asset_info_and_asset(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not pair:
raise ValueError(f"AssetInfo {asset_info_id} not found")
info, asset = pair
states = list_cache_states_by_asset_id(session, asset_id=asset.id)
abs_path = pick_best_live_path(states)
if not abs_path:
raise FileNotFoundError
touch_asset_info_by_id(session, asset_info_id=asset_info_id)
session.commit()
ctype = asset.mime_type or mimetypes.guess_type(info.name or abs_path)[0] or "application/octet-stream"
download_name = info.name or os.path.basename(abs_path)
return abs_path, ctype, download_name
def upload_asset_from_temp_path(
spec: schemas_in.UploadAssetSpec,
*,
temp_path: str,
client_filename: str | None = None,
owner_id: str = "",
expected_asset_hash: str | None = None,
) -> schemas_out.AssetCreated:
"""
Create new asset or update existing asset from a temporary file path.
"""
try:
# NOTE: blake3 is not required right now, so this will fail if blake3 is not installed in local environment
import app.assets.hashing as hashing
digest = hashing.blake3_hash(temp_path)
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_asset_hash and asset_hash != expected_asset_hash.strip().lower():
raise ValueError("HASH_MISMATCH")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _safe_filename(spec.name or (client_filename or ""), fallback=digest)
info = create_asset_info_for_existing_asset(
session,
asset_hash=asset_hash,
name=display_name,
user_metadata=spec.user_metadata or {},
tags=spec.tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
session.commit()
return schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=existing.hash,
size=int(existing.size_bytes) if existing.size_bytes is not None else None,
mime_type=existing.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
base_dir, subdirs = resolve_destination_from_tags(spec.tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or spec.name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
ensure_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = _get_size_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
with create_session() as session:
result = ingest_fs_asset(
session,
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_safe_filename(spec.name or (client_filename or ""), fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=spec.user_metadata or {},
tags=spec.tags,
tag_origin="manual",
require_existing_tags=False,
)
info_id = result["asset_info_id"]
if not info_id:
raise RuntimeError("failed to create asset metadata")
pair = fetch_asset_info_and_asset(session, asset_info_id=info_id, owner_id=owner_id)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
info, asset = pair
tag_names = get_asset_tags(session, asset_info_id=info.id)
created_result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=result["asset_created"],
)
session.commit()
return created_result
def update_asset(
*,
asset_info_id: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetUpdated:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
info = update_asset_info_full(
session,
asset_info_id=asset_info_id,
name=name,
tags=tags,
user_metadata=user_metadata,
tag_origin="manual",
asset_info_row=info_row,
)
tag_names = get_asset_tags(session, asset_info_id=asset_info_id)
result = schemas_out.AssetUpdated(
id=info.id,
name=info.name,
asset_hash=info.asset.hash if info.asset else None,
tags=tag_names,
user_metadata=info.user_metadata or {},
updated_at=info.updated_at,
)
session.commit()
return result
def set_asset_preview(
*,
asset_info_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> schemas_out.AssetDetail:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
set_asset_info_preview(
session,
asset_info_id=asset_info_id,
preview_asset_id=preview_asset_id,
)
res = fetch_asset_info_asset_and_tags(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not res:
raise RuntimeError("State changed during preview update")
info, asset, tags = res
result = schemas_out.AssetDetail(
id=info.id,
name=info.name,
asset_hash=asset.hash if asset else None,
size=int(asset.size_bytes) if asset and asset.size_bytes is not None else None,
mime_type=asset.mime_type if asset else None,
tags=tags,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
)
session.commit()
return result
def delete_asset_reference(*, asset_info_id: str, owner_id: str, delete_content_if_orphan: bool = True) -> bool:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
asset_id = info_row.asset_id if info_row else None
deleted = delete_asset_info_by_id(session, asset_info_id=asset_info_id, owner_id=owner_id)
if not deleted:
session.commit()
return False
if not delete_content_if_orphan or not asset_id:
session.commit()
return True
still_exists = asset_info_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
states = list_cache_states_by_asset_id(session, asset_id=asset_id)
file_paths = [s.file_path for s in (states or []) if getattr(s, "file_path", None)]
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def create_asset_from_hash(
*,
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> schemas_out.AssetCreated | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
info = create_asset_info_for_existing_asset(
session,
asset_hash=canonical,
name=_safe_filename(name, fallback=canonical.split(":", 1)[1]),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
tag_names = get_asset_tags(session, asset_info_id=info.id)
result = schemas_out.AssetCreated(
id=info.id,
name=info.name,
asset_hash=asset.hash,
size=int(asset.size_bytes),
mime_type=asset.mime_type,
tags=tag_names,
user_metadata=info.user_metadata or {},
preview_id=info.preview_id,
created_at=info.created_at,
last_access_time=info.last_access_time,
created_new=False,
)
session.commit()
return result
def add_tags_to_asset(
*,
asset_info_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> schemas_out.TagsAdd:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = add_tags_to_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
origin=origin,
create_if_missing=True,
asset_info_row=info_row,
)
session.commit()
return schemas_out.TagsAdd(**data)
def remove_tags_from_asset(
*,
asset_info_id: str,
tags: list[str],
owner_id: str = "",
) -> schemas_out.TagsRemove:
with create_session() as session:
info_row = get_asset_info_by_id(session, asset_info_id=asset_info_id)
if not info_row:
raise ValueError(f"AssetInfo {asset_info_id} not found")
if info_row.owner_id and info_row.owner_id != owner_id:
raise PermissionError("not owner")
data = remove_tags_from_asset_info(
session,
asset_info_id=asset_info_id,
tags=tags,
)
session.commit()
return schemas_out.TagsRemove(**data)
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> schemas_out.TagsList:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
tags = [schemas_out.TagUsage(name=name, count=count, type=tag_type) for (name, tag_type, count) in rows]
return schemas_out.TagsList(tags=tags, total=total, has_more=(offset + len(tags)) < total)

View File

@ -1,263 +1,567 @@
import contextlib
import time
import logging
import os
import sqlalchemy
from pathlib import Path
from typing import Callable, Literal, TypedDict
import folder_paths
from app.database.db import create_session, dependencies_available
from app.assets.helpers import (
collect_models_files, compute_relative_filename, fast_asset_file_check, get_name_and_tags_from_asset_path,
list_tree,prefixes_for_root, escape_like_prefix,
RootType
from app.assets.database.queries import (
add_missing_tag_for_asset_id,
bulk_update_enrichment_level,
bulk_update_is_missing,
bulk_update_needs_verify,
delete_orphaned_seed_asset,
delete_references_by_ids,
ensure_tags_exist,
get_asset_by_hash,
get_references_for_prefixes,
get_unenriched_references,
mark_references_missing_outside_prefixes,
reassign_asset_references,
remove_missing_tag_for_asset_id,
set_reference_metadata,
update_asset_hash_and_mime,
)
from app.assets.database.tags import add_missing_tag_for_asset_id, ensure_tags_exist, remove_missing_tag_for_asset_id
from app.assets.database.bulk_ops import seed_from_paths_batch
from app.assets.database.models import Asset, AssetCacheState, AssetInfo
from app.assets.services.bulk_ingest import (
SeedAssetSpec,
batch_insert_seed_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
is_visible,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
from app.assets.services.metadata_extract import extract_file_metadata
from app.assets.services.path_utils import (
compute_relative_filename,
get_comfy_models_folders,
get_name_and_tags_from_asset_path,
)
from app.database.db import create_session
def seed_assets(roots: tuple[RootType, ...], enable_logging: bool = False) -> None:
"""
Scan the given roots and seed the assets into the database.
"""
if not dependencies_available():
if enable_logging:
logging.warning("Database dependencies not available, skipping assets scan")
return
t_start = time.perf_counter()
created = 0
skipped_existing = 0
orphans_pruned = 0
paths: list[str] = []
try:
existing_paths: set[str] = set()
for r in roots:
try:
survivors: set[str] = _fast_db_consistency_pass(r, collect_existing_paths=True, update_missing_tags=True)
if survivors:
existing_paths.update(survivors)
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", r, e)
class _RefInfo(TypedDict):
ref_id: str
file_path: str
exists: bool
stat_unchanged: bool
needs_verify: bool
try:
orphans_pruned = _prune_orphaned_assets(roots)
except Exception as e:
logging.exception("orphan pruning failed: %s", e)
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_tree(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_tree(folder_paths.get_output_directory()))
class _AssetAccumulator(TypedDict):
hash: str | None
size_db: int
refs: list[_RefInfo]
specs: list[dict] = []
tag_pool: set[str] = set()
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped_existing += 1
RootType = Literal["models", "input", "output"]
def get_prefixes_for_root(root: RootType) -> list[str]:
if root == "models":
bases: list[str] = []
for _bucket, paths in get_comfy_models_folders():
bases.extend(paths)
return [os.path.abspath(p) for p in bases]
if root == "input":
return [os.path.abspath(folder_paths.get_input_directory())]
if root == "output":
return [os.path.abspath(folder_paths.get_output_directory())]
return []
def get_all_known_prefixes() -> list[str]:
"""Get all known asset prefixes across all root types."""
all_roots: tuple[RootType, ...] = ("models", "input", "output")
return [p for root in all_roots for p in get_prefixes_for_root(root)]
def collect_models_files() -> list[str]:
out: list[str] = []
for folder_name, bases in get_comfy_models_folders():
rel_files = folder_paths.get_filename_list(folder_name) or []
for rel_path in rel_files:
if not all(is_visible(part) for part in Path(rel_path).parts):
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=False)
except OSError:
abs_path = folder_paths.get_full_path(folder_name, rel_path)
if not abs_path:
continue
# skip empty files
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": getattr(stat_p, "st_mtime_ns", int(stat_p.st_mtime * 1_000_000_000)),
"info_name": name,
"tags": tags,
"fname": compute_relative_filename(abs_p),
}
)
for t in tags:
tag_pool.add(t)
# if no file specs, nothing to do
if not specs:
return
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = seed_from_paths_batch(sess, specs=specs, owner_id="")
created += result["inserted_infos"]
sess.commit()
finally:
if enable_logging:
logging.info(
"Assets scan(roots=%s) completed in %.3fs (created=%d, skipped_existing=%d, orphans_pruned=%d, total_seen=%d)",
roots,
time.perf_counter() - t_start,
created,
skipped_existing,
orphans_pruned,
len(paths),
)
abs_path = os.path.abspath(abs_path)
allowed = False
abs_p = Path(abs_path)
for b in bases:
if abs_p.is_relative_to(os.path.abspath(b)):
allowed = True
break
if allowed:
out.append(abs_path)
return out
def _prune_orphaned_assets(roots: tuple[RootType, ...]) -> int:
"""Prune cache states outside configured prefixes, then delete orphaned seed assets."""
all_prefixes = [os.path.abspath(p) for r in roots for p in prefixes_for_root(r)]
if not all_prefixes:
return 0
def make_prefix_condition(prefix: str):
base = prefix if prefix.endswith(os.sep) else prefix + os.sep
escaped, esc = escape_like_prefix(base)
return AssetCacheState.file_path.like(escaped + "%", escape=esc)
matches_valid_prefix = sqlalchemy.or_(*[make_prefix_condition(p) for p in all_prefixes])
orphan_subq = (
sqlalchemy.select(Asset.id)
.outerjoin(AssetCacheState, AssetCacheState.asset_id == Asset.id)
.where(Asset.hash.is_(None), AssetCacheState.id.is_(None))
).scalar_subquery()
with create_session() as sess:
sess.execute(sqlalchemy.delete(AssetCacheState).where(~matches_valid_prefix))
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id.in_(orphan_subq)))
result = sess.execute(sqlalchemy.delete(Asset).where(Asset.id.in_(orphan_subq)))
sess.commit()
return result.rowcount
def _fast_db_consistency_pass(
def sync_references_with_filesystem(
session,
root: RootType,
*,
collect_existing_paths: bool = False,
update_missing_tags: bool = False,
) -> set[str] | None:
"""Fast DB+FS pass for a root:
- Toggle needs_verify per state using fast check
- For hashed assets with at least one fast-ok state in this root: delete stale missing states
- For seed assets with all states missing: delete Asset and its AssetInfos
- Optionally add/remove 'missing' tags based on fast-ok in this root
- Optionally return surviving absolute paths
"""Reconcile asset references with filesystem for a root.
- Toggle needs_verify per reference using mtime/size stat check
- For hashed assets with at least one stat-unchanged ref: delete stale missing refs
- For seed assets with all refs missing: delete Asset and its references
- Optionally add/remove 'missing' tags based on stat check in this root
- Optionally return surviving absolute paths
Args:
session: Database session
root: Root type to scan
collect_existing_paths: If True, return set of surviving file paths
update_missing_tags: If True, update 'missing' tags based on file status
Returns:
Set of surviving absolute paths if collect_existing_paths=True, else None
"""
prefixes = prefixes_for_root(root)
prefixes = get_prefixes_for_root(root)
if not prefixes:
return set() if collect_existing_paths else None
conds = []
for p in prefixes:
base = os.path.abspath(p)
if not base.endswith(os.sep):
base += os.sep
escaped, esc = escape_like_prefix(base)
conds.append(AssetCacheState.file_path.like(escaped + "%", escape=esc))
rows = get_references_for_prefixes(
session, prefixes, include_missing=update_missing_tags
)
by_asset: dict[str, _AssetAccumulator] = {}
for row in rows:
acc = by_asset.get(row.asset_id)
if acc is None:
acc = {"hash": row.asset_hash, "size_db": row.size_bytes, "refs": []}
by_asset[row.asset_id] = acc
stat_unchanged = False
try:
exists = True
stat_unchanged = verify_file_unchanged(
mtime_db=row.mtime_ns,
size_db=acc["size_db"],
stat_result=os.stat(row.file_path, follow_symlinks=True),
)
except FileNotFoundError:
exists = False
except PermissionError:
exists = True
logging.debug("Permission denied accessing %s", row.file_path)
except OSError as e:
exists = False
logging.debug("OSError checking %s: %s", row.file_path, e)
acc["refs"].append(
{
"ref_id": row.reference_id,
"file_path": row.file_path,
"exists": exists,
"stat_unchanged": stat_unchanged,
"needs_verify": row.needs_verify,
}
)
to_set_verify: list[str] = []
to_clear_verify: list[str] = []
stale_ref_ids: list[str] = []
to_mark_missing: list[str] = []
to_clear_missing: list[str] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
refs = acc["refs"]
any_unchanged = any(r["stat_unchanged"] for r in refs)
all_missing = all(not r["exists"] for r in refs)
for r in refs:
if not r["exists"]:
to_mark_missing.append(r["ref_id"])
continue
if r["stat_unchanged"]:
to_clear_missing.append(r["ref_id"])
if r["needs_verify"]:
to_clear_verify.append(r["ref_id"])
if not r["stat_unchanged"] and not r["needs_verify"]:
to_set_verify.append(r["ref_id"])
if a_hash is None:
if refs and all_missing:
delete_orphaned_seed_asset(session, aid)
else:
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["file_path"]))
continue
if any_unchanged:
for r in refs:
if not r["exists"]:
stale_ref_ids.append(r["ref_id"])
if update_missing_tags:
try:
remove_missing_tag_for_asset_id(session, asset_id=aid)
except Exception as e:
logging.warning(
"Failed to remove missing tag for asset %s: %s", aid, e
)
elif update_missing_tags:
try:
add_missing_tag_for_asset_id(session, asset_id=aid, origin="automatic")
except Exception as e:
logging.warning("Failed to add missing tag for asset %s: %s", aid, e)
for r in refs:
if r["exists"]:
survivors.add(os.path.abspath(r["file_path"]))
delete_references_by_ids(session, stale_ref_ids)
stale_set = set(stale_ref_ids)
to_mark_missing = [ref_id for ref_id in to_mark_missing if ref_id not in stale_set]
bulk_update_is_missing(session, to_mark_missing, value=True)
bulk_update_is_missing(session, to_clear_missing, value=False)
bulk_update_needs_verify(session, to_set_verify, value=True)
bulk_update_needs_verify(session, to_clear_verify, value=False)
return survivors if collect_existing_paths else None
def sync_root_safely(root: RootType) -> set[str]:
"""Sync a single root's references with the filesystem.
Returns survivors (existing paths) or empty set on failure.
"""
try:
with create_session() as sess:
survivors = sync_references_with_filesystem(
sess,
root,
collect_existing_paths=True,
update_missing_tags=True,
)
sess.commit()
return survivors or set()
except Exception as e:
logging.exception("fast DB scan failed for %s: %s", root, e)
return set()
def mark_missing_outside_prefixes_safely(prefixes: list[str]) -> int:
"""Mark references as missing when outside the given prefixes.
This is a non-destructive soft-delete. Returns count marked or 0 on failure.
"""
try:
with create_session() as sess:
count = mark_references_missing_outside_prefixes(sess, prefixes)
sess.commit()
return count
except Exception as e:
logging.exception("marking missing assets failed: %s", e)
return 0
def collect_paths_for_roots(roots: tuple[RootType, ...]) -> list[str]:
"""Collect all file paths for the given roots."""
paths: list[str] = []
if "models" in roots:
paths.extend(collect_models_files())
if "input" in roots:
paths.extend(list_files_recursively(folder_paths.get_input_directory()))
if "output" in roots:
paths.extend(list_files_recursively(folder_paths.get_output_directory()))
return paths
def build_asset_specs(
paths: list[str],
existing_paths: set[str],
enable_metadata_extraction: bool = True,
compute_hashes: bool = False,
) -> tuple[list[SeedAssetSpec], set[str], int]:
"""Build asset specs from paths, returning (specs, tag_pool, skipped_count).
Args:
paths: List of file paths to process
existing_paths: Set of paths that already exist in the database
enable_metadata_extraction: If True, extract tier 1 & 2 metadata
compute_hashes: If True, compute blake3 hashes (slow for large files)
"""
specs: list[SeedAssetSpec] = []
tag_pool: set[str] = set()
skipped = 0
for p in paths:
abs_p = os.path.abspath(p)
if abs_p in existing_paths:
skipped += 1
continue
try:
stat_p = os.stat(abs_p, follow_symlinks=True)
except OSError:
continue
if not stat_p.st_size:
continue
name, tags = get_name_and_tags_from_asset_path(abs_p)
rel_fname = compute_relative_filename(abs_p)
# Extract metadata (tier 1: filesystem, tier 2: safetensors header)
metadata = None
if enable_metadata_extraction:
metadata = extract_file_metadata(
abs_p,
stat_result=stat_p,
relative_filename=rel_fname,
)
# Compute hash if requested
asset_hash: str | None = None
if compute_hashes:
try:
digest, _ = compute_blake3_hash(abs_p)
asset_hash = "blake3:" + digest
except Exception as e:
logging.warning("Failed to hash %s: %s", abs_p, e)
mime_type = metadata.content_type if metadata else None
specs.append(
{
"abs_path": abs_p,
"size_bytes": stat_p.st_size,
"mtime_ns": get_mtime_ns(stat_p),
"info_name": name,
"tags": tags,
"fname": rel_fname,
"metadata": metadata,
"hash": asset_hash,
"mime_type": mime_type,
}
)
tag_pool.update(tags)
return specs, tag_pool, skipped
def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
"""Insert asset specs into database, returning count of created refs."""
if not specs:
return 0
with create_session() as sess:
if tag_pool:
ensure_tags_exist(sess, tag_pool, tag_type="user")
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
sess.commit()
return result.inserted_refs
# Enrichment level constants
ENRICHMENT_STUB = 0 # Fast scan: path, size, mtime only
ENRICHMENT_METADATA = 1 # Metadata extracted (safetensors header, mime type)
ENRICHMENT_HASHED = 2 # Hash computed (blake3)
def get_unenriched_assets_for_roots(
roots: tuple[RootType, ...],
max_level: int = ENRICHMENT_STUB,
limit: int = 1000,
) -> list:
"""Get assets that need enrichment for the given roots.
Args:
roots: Tuple of root types to scan
max_level: Maximum enrichment level to include
limit: Maximum number of rows to return
Returns:
List of UnenrichedReferenceRow
"""
prefixes: list[str] = []
for root in roots:
prefixes.extend(get_prefixes_for_root(root))
if not prefixes:
return []
with create_session() as sess:
rows = (
sess.execute(
sqlalchemy.select(
AssetCacheState.id,
AssetCacheState.file_path,
AssetCacheState.mtime_ns,
AssetCacheState.needs_verify,
AssetCacheState.asset_id,
Asset.hash,
Asset.size_bytes,
)
.join(Asset, Asset.id == AssetCacheState.asset_id)
.where(sqlalchemy.or_(*conds))
.order_by(AssetCacheState.asset_id.asc(), AssetCacheState.id.asc())
return get_unenriched_references(
sess, prefixes, max_level=max_level, limit=limit
)
def enrich_asset(
session,
file_path: str,
reference_id: str,
asset_id: str,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> int:
"""Enrich a single asset with metadata and/or hash.
Args:
session: Database session (caller manages lifecycle)
file_path: Absolute path to the file
reference_id: ID of the reference to update
asset_id: ID of the asset to update (for mime_type and hash)
extract_metadata: If True, extract safetensors header and mime type
compute_hash: If True, compute blake3 hash
interrupt_check: Optional non-blocking callable that returns True if
the operation should be interrupted (e.g. paused or cancelled)
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
New enrichment level achieved
"""
new_level = ENRICHMENT_STUB
try:
stat_p = os.stat(file_path, follow_symlinks=True)
except OSError:
return new_level
rel_fname = compute_relative_filename(file_path)
mime_type: str | None = None
metadata = None
if extract_metadata:
metadata = extract_file_metadata(
file_path,
stat_result=stat_p,
relative_filename=rel_fname,
)
if metadata:
mime_type = metadata.content_type
new_level = ENRICHMENT_METADATA
full_hash: str | None = None
if compute_hash:
try:
mtime_before = get_mtime_ns(stat_p)
size_before = stat_p.st_size
# Restore checkpoint if available and file unchanged
checkpoint = None
if hash_checkpoints is not None:
checkpoint = hash_checkpoints.get(file_path)
if checkpoint is not None:
cur_stat = os.stat(file_path, follow_symlinks=True)
if (checkpoint.mtime_ns != get_mtime_ns(cur_stat)
or checkpoint.file_size != cur_stat.st_size):
checkpoint = None
hash_checkpoints.pop(file_path, None)
else:
mtime_before = get_mtime_ns(cur_stat)
digest, new_checkpoint = compute_blake3_hash(
file_path,
interrupt_check=interrupt_check,
checkpoint=checkpoint,
)
).all()
by_asset: dict[str, dict] = {}
for sid, fp, mtime_db, needs_verify, aid, a_hash, a_size in rows:
acc = by_asset.get(aid)
if acc is None:
acc = {"hash": a_hash, "size_db": int(a_size or 0), "states": []}
by_asset[aid] = acc
if digest is None:
# Interrupted — save checkpoint for later resumption
if hash_checkpoints is not None and new_checkpoint is not None:
new_checkpoint.mtime_ns = mtime_before
new_checkpoint.file_size = size_before
hash_checkpoints[file_path] = new_checkpoint
return new_level
# Completed — clear any saved checkpoint
if hash_checkpoints is not None:
hash_checkpoints.pop(file_path, None)
stat_after = os.stat(file_path, follow_symlinks=True)
mtime_after = get_mtime_ns(stat_after)
if mtime_before != mtime_after:
logging.warning("File modified during hashing, discarding hash: %s", file_path)
else:
full_hash = f"blake3:{digest}"
metadata_ok = not extract_metadata or metadata is not None
if metadata_ok:
new_level = ENRICHMENT_HASHED
except Exception as e:
logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata()
set_reference_metadata(session, reference_id, user_metadata)
if full_hash:
existing = get_asset_by_hash(session, full_hash)
if existing and existing.id != asset_id:
reassign_asset_references(session, asset_id, existing.id, reference_id)
delete_orphaned_seed_asset(session, asset_id)
if mime_type:
update_asset_hash_and_mime(session, existing.id, mime_type=mime_type)
else:
update_asset_hash_and_mime(session, asset_id, full_hash, mime_type)
elif mime_type:
update_asset_hash_and_mime(session, asset_id, mime_type=mime_type)
bulk_update_enrichment_level(session, [reference_id], new_level)
session.commit()
return new_level
def enrich_assets_batch(
rows: list,
extract_metadata: bool = True,
compute_hash: bool = False,
interrupt_check: Callable[[], bool] | None = None,
hash_checkpoints: dict[str, HashCheckpoint] | None = None,
) -> tuple[int, list[str]]:
"""Enrich a batch of assets.
Uses a single DB session for the entire batch, committing after each
individual asset to avoid long-held transactions while eliminating
per-asset session creation overhead.
Args:
rows: List of UnenrichedReferenceRow from get_unenriched_assets_for_roots
extract_metadata: If True, extract metadata for each asset
compute_hash: If True, compute hash for each asset
interrupt_check: Optional non-blocking callable that returns True if
the operation should be interrupted (e.g. paused or cancelled)
hash_checkpoints: Optional dict for saving/restoring hash progress
across interruptions, keyed by file path
Returns:
Tuple of (enriched_count, failed_reference_ids)
"""
enriched = 0
failed_ids: list[str] = []
with create_session() as sess:
for row in rows:
if interrupt_check is not None and interrupt_check():
break
fast_ok = False
try:
exists = True
fast_ok = fast_asset_file_check(
mtime_db=mtime_db,
size_db=acc["size_db"],
stat_result=os.stat(fp, follow_symlinks=True),
new_level = enrich_asset(
sess,
file_path=row.file_path,
reference_id=row.reference_id,
asset_id=row.asset_id,
extract_metadata=extract_metadata,
compute_hash=compute_hash,
interrupt_check=interrupt_check,
hash_checkpoints=hash_checkpoints,
)
except FileNotFoundError:
exists = False
except OSError:
exists = False
acc["states"].append({
"sid": sid,
"fp": fp,
"exists": exists,
"fast_ok": fast_ok,
"needs_verify": bool(needs_verify),
})
to_set_verify: list[int] = []
to_clear_verify: list[int] = []
stale_state_ids: list[int] = []
survivors: set[str] = set()
for aid, acc in by_asset.items():
a_hash = acc["hash"]
states = acc["states"]
any_fast_ok = any(s["fast_ok"] for s in states)
all_missing = all(not s["exists"] for s in states)
for s in states:
if not s["exists"]:
continue
if s["fast_ok"] and s["needs_verify"]:
to_clear_verify.append(s["sid"])
if not s["fast_ok"] and not s["needs_verify"]:
to_set_verify.append(s["sid"])
if a_hash is None:
if states and all_missing: # remove seed Asset completely, if no valid AssetCache exists
sess.execute(sqlalchemy.delete(AssetInfo).where(AssetInfo.asset_id == aid))
asset = sess.get(Asset, aid)
if asset:
sess.delete(asset)
if new_level > row.enrichment_level:
enriched += 1
else:
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
continue
failed_ids.append(row.reference_id)
except Exception as e:
logging.warning("Failed to enrich %s: %s", row.file_path, e)
sess.rollback()
failed_ids.append(row.reference_id)
if any_fast_ok: # if Asset has at least one valid AssetCache record, remove any invalid AssetCache records
for s in states:
if not s["exists"]:
stale_state_ids.append(s["sid"])
if update_missing_tags:
with contextlib.suppress(Exception):
remove_missing_tag_for_asset_id(sess, asset_id=aid)
elif update_missing_tags:
with contextlib.suppress(Exception):
add_missing_tag_for_asset_id(sess, asset_id=aid, origin="automatic")
for s in states:
if s["exists"]:
survivors.add(os.path.abspath(s["fp"]))
if stale_state_ids:
sess.execute(sqlalchemy.delete(AssetCacheState).where(AssetCacheState.id.in_(stale_state_ids)))
if to_set_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_set_verify))
.values(needs_verify=True)
)
if to_clear_verify:
sess.execute(
sqlalchemy.update(AssetCacheState)
.where(AssetCacheState.id.in_(to_clear_verify))
.values(needs_verify=False)
)
sess.commit()
return survivors if collect_existing_paths else None
return enriched, failed_ids

794
app/assets/seeder.py Normal file
View File

@ -0,0 +1,794 @@
"""Background asset seeder with thread management and cancellation support."""
import logging
import os
import threading
import time
from dataclasses import dataclass, field
from enum import Enum
from typing import Callable
from app.assets.scanner import (
ENRICHMENT_METADATA,
ENRICHMENT_STUB,
RootType,
build_asset_specs,
collect_paths_for_roots,
enrich_assets_batch,
get_all_known_prefixes,
get_prefixes_for_root,
get_unenriched_assets_for_roots,
insert_asset_specs,
mark_missing_outside_prefixes_safely,
sync_root_safely,
)
from app.database.db import dependencies_available
class ScanInProgressError(Exception):
"""Raised when an operation cannot proceed because a scan is running."""
class State(Enum):
"""Seeder state machine states."""
IDLE = "IDLE"
RUNNING = "RUNNING"
PAUSED = "PAUSED"
CANCELLING = "CANCELLING"
class ScanPhase(Enum):
"""Scan phase options."""
FAST = "fast" # Phase 1: filesystem only (stubs)
ENRICH = "enrich" # Phase 2: metadata + hash
FULL = "full" # Both phases sequentially
@dataclass
class Progress:
"""Progress information for a scan operation."""
scanned: int = 0
total: int = 0
created: int = 0
skipped: int = 0
@dataclass
class ScanStatus:
"""Current status of the asset seeder."""
state: State
progress: Progress | None
errors: list[str] = field(default_factory=list)
ProgressCallback = Callable[[Progress], None]
class _AssetSeeder:
"""Background asset scanning manager.
Spawns ephemeral daemon threads for scanning.
Each scan creates a new thread that exits when complete.
Use the module-level ``asset_seeder`` instance.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._state = State.IDLE
self._progress: Progress | None = None
self._last_progress: Progress | None = None
self._errors: list[str] = []
self._thread: threading.Thread | None = None
self._cancel_event = threading.Event()
self._run_gate = threading.Event()
self._run_gate.set() # Start unpaused (set = running, clear = paused)
self._roots: tuple[RootType, ...] = ()
self._phase: ScanPhase = ScanPhase.FULL
self._compute_hashes: bool = False
self._prune_first: bool = False
self._progress_callback: ProgressCallback | None = None
self._disabled: bool = False
def disable(self) -> None:
"""Disable the asset seeder, preventing any scans from starting."""
self._disabled = True
logging.info("Asset seeder disabled")
def is_disabled(self) -> bool:
"""Check if the asset seeder is disabled."""
return self._disabled
def start(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
phase: ScanPhase = ScanPhase.FULL,
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
compute_hashes: bool = False,
) -> bool:
"""Start a background scan for the given roots.
Args:
roots: Tuple of root types to scan (models, input, output)
phase: Scan phase to run (FAST, ENRICH, or FULL for both)
progress_callback: Optional callback called with progress updates
prune_first: If True, prune orphaned assets before scanning
compute_hashes: If True, compute blake3 hashes (slow)
Returns:
True if scan was started, False if already running
"""
if self._disabled:
logging.debug("Asset seeder is disabled, skipping start")
return False
logging.info("Seeder start (roots=%s, phase=%s)", roots, phase.value)
with self._lock:
if self._state != State.IDLE:
logging.info("Asset seeder already running, skipping start")
return False
self._state = State.RUNNING
self._progress = Progress()
self._errors = []
self._roots = roots
self._phase = phase
self._prune_first = prune_first
self._compute_hashes = compute_hashes
self._progress_callback = progress_callback
self._cancel_event.clear()
self._run_gate.set() # Ensure unpaused when starting
self._thread = threading.Thread(
target=self._run_scan,
name="_AssetSeeder",
daemon=True,
)
self._thread.start()
return True
def start_fast(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
prune_first: bool = False,
) -> bool:
"""Start a fast scan (phase 1 only) - creates stub records.
Args:
roots: Tuple of root types to scan
progress_callback: Optional callback for progress updates
prune_first: If True, prune orphaned assets before scanning
Returns:
True if scan was started, False if already running
"""
return self.start(
roots=roots,
phase=ScanPhase.FAST,
progress_callback=progress_callback,
prune_first=prune_first,
compute_hashes=False,
)
def start_enrich(
self,
roots: tuple[RootType, ...] = ("models", "input", "output"),
progress_callback: ProgressCallback | None = None,
compute_hashes: bool = False,
) -> bool:
"""Start an enrichment scan (phase 2 only) - extracts metadata and hashes.
Args:
roots: Tuple of root types to scan
progress_callback: Optional callback for progress updates
compute_hashes: If True, compute blake3 hashes
Returns:
True if scan was started, False if already running
"""
return self.start(
roots=roots,
phase=ScanPhase.ENRICH,
progress_callback=progress_callback,
prune_first=False,
compute_hashes=compute_hashes,
)
def cancel(self) -> bool:
"""Request cancellation of the current scan.
Returns:
True if cancellation was requested, False if not running or paused
"""
with self._lock:
if self._state not in (State.RUNNING, State.PAUSED):
return False
logging.info("Asset seeder cancelling (was %s)", self._state.value)
self._state = State.CANCELLING
self._cancel_event.set()
self._run_gate.set() # Unblock if paused so thread can exit
return True
def stop(self) -> bool:
"""Stop the current scan (alias for cancel).
Returns:
True if stop was requested, False if not running
"""
return self.cancel()
def pause(self) -> bool:
"""Pause the current scan.
The scan will complete its current batch before pausing.
Returns:
True if pause was requested, False if not running
"""
with self._lock:
if self._state != State.RUNNING:
return False
logging.info("Asset seeder pausing")
self._state = State.PAUSED
self._run_gate.clear()
return True
def resume(self) -> bool:
"""Resume a paused scan.
This is a noop if the scan is not in the PAUSED state
Returns:
True if resumed, False if not paused
"""
with self._lock:
if self._state != State.PAUSED:
return False
logging.info("Asset seeder resuming")
self._state = State.RUNNING
self._run_gate.set()
self._emit_event("assets.seed.resumed", {})
return True
def restart(
self,
roots: tuple[RootType, ...] | None = None,
phase: ScanPhase | None = None,
progress_callback: ProgressCallback | None = None,
prune_first: bool | None = None,
compute_hashes: bool | None = None,
timeout: float = 5.0,
) -> bool:
"""Cancel any running scan and start a new one.
Args:
roots: Roots to scan (defaults to previous roots)
phase: Scan phase (defaults to previous phase)
progress_callback: Progress callback (defaults to previous)
prune_first: Prune before scan (defaults to previous)
compute_hashes: Compute hashes (defaults to previous)
timeout: Max seconds to wait for current scan to stop
Returns:
True if new scan was started, False if failed to stop previous
"""
logging.info("Asset seeder restart requested")
with self._lock:
prev_roots = self._roots
prev_phase = self._phase
prev_callback = self._progress_callback
prev_prune = self._prune_first
prev_hashes = self._compute_hashes
self.cancel()
if not self.wait(timeout=timeout):
return False
cb = progress_callback if progress_callback is not None else prev_callback
return self.start(
roots=roots if roots is not None else prev_roots,
phase=phase if phase is not None else prev_phase,
progress_callback=cb,
prune_first=prune_first if prune_first is not None else prev_prune,
compute_hashes=(
compute_hashes if compute_hashes is not None else prev_hashes
),
)
def wait(self, timeout: float | None = None) -> bool:
"""Wait for the current scan to complete.
Args:
timeout: Maximum seconds to wait, or None for no timeout
Returns:
True if scan completed, False if timeout expired or no scan running
"""
with self._lock:
thread = self._thread
if thread is None:
return True
thread.join(timeout=timeout)
return not thread.is_alive()
def get_status(self) -> ScanStatus:
"""Get the current status and progress of the seeder."""
with self._lock:
src = self._progress or self._last_progress
return ScanStatus(
state=self._state,
progress=Progress(
scanned=src.scanned,
total=src.total,
created=src.created,
skipped=src.skipped,
)
if src
else None,
errors=list(self._errors),
)
def shutdown(self, timeout: float = 5.0) -> None:
"""Gracefully shutdown: cancel any running scan and wait for thread.
Args:
timeout: Maximum seconds to wait for thread to exit
"""
self.cancel()
self.wait(timeout=timeout)
with self._lock:
self._thread = None
def mark_missing_outside_prefixes(self) -> int:
"""Mark references as missing when outside all known root prefixes.
This is a non-destructive soft-delete operation. Assets and their
metadata are preserved, but references are flagged as missing.
They can be restored if the file reappears in a future scan.
This operation is decoupled from scanning to prevent partial scans
from accidentally marking assets belonging to other roots.
Should be called explicitly when cleanup is desired, typically after
a full scan of all roots or during maintenance.
Returns:
Number of references marked as missing
Raises:
ScanInProgressError: If a scan is currently running
"""
with self._lock:
if self._state != State.IDLE:
raise ScanInProgressError(
"Cannot mark missing assets while scan is running"
)
self._state = State.RUNNING
try:
if not dependencies_available():
logging.warning(
"Database dependencies not available, skipping mark missing"
)
return 0
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d references as missing", marked)
return marked
finally:
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _is_cancelled(self) -> bool:
"""Check if cancellation has been requested."""
return self._cancel_event.is_set()
def _is_paused_or_cancelled(self) -> bool:
"""Non-blocking check: True if paused or cancelled.
Use as interrupt_check for I/O-bound work (e.g. hashing) so that
file handles are released immediately on pause rather than held
open while blocked. The caller is responsible for blocking on
_check_pause_and_cancel() afterward.
"""
return not self._run_gate.is_set() or self._cancel_event.is_set()
def _check_pause_and_cancel(self) -> bool:
"""Block while paused, then check if cancelled.
Call this at checkpoint locations in scan loops. It will:
1. Block indefinitely while paused (until resume or cancel)
2. Return True if cancelled, False to continue
Returns:
True if scan should stop, False to continue
"""
if not self._run_gate.is_set():
self._emit_event("assets.seed.paused", {})
self._run_gate.wait() # Blocks if paused
return self._is_cancelled()
def _emit_event(self, event_type: str, data: dict) -> None:
"""Emit a WebSocket event if server is available."""
try:
from server import PromptServer
if hasattr(PromptServer, "instance") and PromptServer.instance:
PromptServer.instance.send_sync(event_type, data)
except Exception:
pass
def _update_progress(
self,
scanned: int | None = None,
total: int | None = None,
created: int | None = None,
skipped: int | None = None,
) -> None:
"""Update progress counters (thread-safe)."""
callback: ProgressCallback | None = None
progress: Progress | None = None
with self._lock:
if self._progress is None:
return
if scanned is not None:
self._progress.scanned = scanned
if total is not None:
self._progress.total = total
if created is not None:
self._progress.created = created
if skipped is not None:
self._progress.skipped = skipped
if self._progress_callback:
callback = self._progress_callback
progress = Progress(
scanned=self._progress.scanned,
total=self._progress.total,
created=self._progress.created,
skipped=self._progress.skipped,
)
if callback and progress:
try:
callback(progress)
except Exception:
pass
_MAX_ERRORS = 200
def _add_error(self, message: str) -> None:
"""Add an error message (thread-safe), capped at _MAX_ERRORS."""
with self._lock:
if len(self._errors) < self._MAX_ERRORS:
self._errors.append(message)
def _log_scan_config(self, roots: tuple[RootType, ...]) -> None:
"""Log the directories that will be scanned."""
import folder_paths
for root in roots:
if root == "models":
logging.info(
"Asset scan [models] directory: %s",
os.path.abspath(folder_paths.models_dir),
)
else:
prefixes = get_prefixes_for_root(root)
if prefixes:
logging.info("Asset scan [%s] directories: %s", root, prefixes)
def _run_scan(self) -> None:
"""Main scan loop running in background thread."""
t_start = time.perf_counter()
roots = self._roots
phase = self._phase
cancelled = False
total_created = 0
total_enriched = 0
skipped_existing = 0
total_paths = 0
try:
if not dependencies_available():
self._add_error("Database dependencies not available")
self._emit_event(
"assets.seed.error",
{"message": "Database dependencies not available"},
)
return
if self._prune_first:
all_prefixes = get_all_known_prefixes()
marked = mark_missing_outside_prefixes_safely(all_prefixes)
if marked > 0:
logging.info("Marked %d refs as missing before scan", marked)
if self._check_pause_and_cancel():
logging.info("Asset scan cancelled after pruning phase")
cancelled = True
return
self._log_scan_config(roots)
# Phase 1: Fast scan (stub records)
if phase in (ScanPhase.FAST, ScanPhase.FULL):
created, skipped, paths = self._run_fast_phase(roots)
total_created, skipped_existing, total_paths = created, skipped, paths
if self._check_pause_and_cancel():
cancelled = True
return
self._emit_event(
"assets.seed.fast_complete",
{
"roots": list(roots),
"created": total_created,
"skipped": skipped_existing,
"total": total_paths,
},
)
# Phase 2: Enrichment scan (metadata + hashes)
if phase in (ScanPhase.ENRICH, ScanPhase.FULL):
if self._check_pause_and_cancel():
cancelled = True
return
enrich_cancelled, total_enriched = self._run_enrich_phase(roots)
if enrich_cancelled:
cancelled = True
return
self._emit_event(
"assets.seed.enrich_complete",
{
"roots": list(roots),
"enriched": total_enriched,
},
)
elapsed = time.perf_counter() - t_start
logging.info(
"Scan(%s, %s) done %.3fs: created=%d enriched=%d skipped=%d",
roots,
phase.value,
elapsed,
total_created,
total_enriched,
skipped_existing,
)
self._emit_event(
"assets.seed.completed",
{
"phase": phase.value,
"total": total_paths,
"created": total_created,
"enriched": total_enriched,
"skipped": skipped_existing,
"elapsed": round(elapsed, 3),
},
)
except Exception as e:
self._add_error(f"Scan failed: {e}")
logging.exception("Asset scan failed")
self._emit_event("assets.seed.error", {"message": str(e)})
finally:
if cancelled:
self._emit_event(
"assets.seed.cancelled",
{
"scanned": self._progress.scanned if self._progress else 0,
"total": total_paths,
"created": total_created,
},
)
with self._lock:
self._last_progress = self._progress
self._state = State.IDLE
self._progress = None
def _run_fast_phase(self, roots: tuple[RootType, ...]) -> tuple[int, int, int]:
"""Run phase 1: fast scan to create stub records.
Returns:
Tuple of (total_created, skipped_existing, total_paths)
"""
t_fast_start = time.perf_counter()
total_created = 0
skipped_existing = 0
existing_paths: set[str] = set()
t_sync = time.perf_counter()
for r in roots:
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
existing_paths.update(sync_root_safely(r))
logging.debug(
"Fast scan: sync_root phase took %.3fs (%d existing paths)",
time.perf_counter() - t_sync,
len(existing_paths),
)
if self._check_pause_and_cancel():
return total_created, skipped_existing, 0
t_collect = time.perf_counter()
paths = collect_paths_for_roots(roots)
logging.debug(
"Fast scan: collect_paths took %.3fs (%d paths found)",
time.perf_counter() - t_collect,
len(paths),
)
total_paths = len(paths)
self._update_progress(total=total_paths)
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "total": total_paths, "phase": "fast"},
)
# Use stub specs (no metadata extraction, no hashing)
t_specs = time.perf_counter()
specs, tag_pool, skipped_existing = build_asset_specs(
paths,
existing_paths,
enable_metadata_extraction=False,
compute_hashes=False,
)
logging.debug(
"Fast scan: build_asset_specs took %.3fs (%d specs, %d skipped)",
time.perf_counter() - t_specs,
len(specs),
skipped_existing,
)
self._update_progress(skipped=skipped_existing)
if self._check_pause_and_cancel():
return total_created, skipped_existing, total_paths
batch_size = 500
last_progress_time = time.perf_counter()
progress_interval = 1.0
for i in range(0, len(specs), batch_size):
if self._check_pause_and_cancel():
logging.info(
"Fast scan cancelled after %d/%d files (created=%d)",
i,
len(specs),
total_created,
)
return total_created, skipped_existing, total_paths
batch = specs[i : i + batch_size]
batch_tags = {t for spec in batch for t in spec["tags"]}
try:
created = insert_asset_specs(batch, batch_tags)
total_created += created
except Exception as e:
self._add_error(f"Batch insert failed at offset {i}: {e}")
logging.exception("Batch insert failed at offset %d", i)
scanned = i + len(batch)
now = time.perf_counter()
self._update_progress(scanned=scanned, created=total_created)
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"phase": "fast",
"scanned": scanned,
"total": len(specs),
"created": total_created,
},
)
last_progress_time = now
self._update_progress(scanned=len(specs), created=total_created)
logging.info(
"Fast scan complete: %.3fs total (created=%d, skipped=%d, total_paths=%d)",
time.perf_counter() - t_fast_start,
total_created,
skipped_existing,
total_paths,
)
return total_created, skipped_existing, total_paths
def _run_enrich_phase(self, roots: tuple[RootType, ...]) -> tuple[bool, int]:
"""Run phase 2: enrich existing records with metadata and hashes.
Returns:
Tuple of (cancelled, total_enriched)
"""
total_enriched = 0
batch_size = 100
last_progress_time = time.perf_counter()
progress_interval = 1.0
# Get the target enrichment level based on compute_hashes
if not self._compute_hashes:
target_max_level = ENRICHMENT_STUB
else:
target_max_level = ENRICHMENT_METADATA
self._emit_event(
"assets.seed.started",
{"roots": list(roots), "phase": "enrich"},
)
skip_ids: set[str] = set()
consecutive_empty = 0
max_consecutive_empty = 3
# Hash checkpoints survive across batches so interrupted hashes
# can be resumed without re-reading the entire file.
hash_checkpoints: dict[str, object] = {}
while True:
if self._check_pause_and_cancel():
logging.info("Enrich scan cancelled after %d assets", total_enriched)
return True, total_enriched
# Fetch next batch of unenriched assets
unenriched = get_unenriched_assets_for_roots(
roots,
max_level=target_max_level,
limit=batch_size,
)
# Filter out previously failed references
if skip_ids:
unenriched = [r for r in unenriched if r.reference_id not in skip_ids]
if not unenriched:
break
enriched, failed_ids = enrich_assets_batch(
unenriched,
extract_metadata=True,
compute_hash=self._compute_hashes,
interrupt_check=self._is_paused_or_cancelled,
hash_checkpoints=hash_checkpoints,
)
total_enriched += enriched
skip_ids.update(failed_ids)
if enriched == 0:
consecutive_empty += 1
if consecutive_empty >= max_consecutive_empty:
logging.warning(
"Enrich phase stopping: %d consecutive batches with no progress (%d skipped)",
consecutive_empty,
len(skip_ids),
)
break
else:
consecutive_empty = 0
now = time.perf_counter()
if now - last_progress_time >= progress_interval:
self._emit_event(
"assets.seed.progress",
{
"phase": "enrich",
"enriched": total_enriched,
},
)
last_progress_time = now
return False, total_enriched
asset_seeder = _AssetSeeder()

View File

@ -0,0 +1,87 @@
from app.assets.services.asset_management import (
asset_exists,
delete_asset_reference,
get_asset_by_hash,
get_asset_detail,
list_assets_page,
resolve_asset_for_download,
set_asset_preview,
update_asset_metadata,
)
from app.assets.services.bulk_ingest import (
BulkInsertResult,
batch_insert_seed_assets,
cleanup_unreferenced_assets,
)
from app.assets.services.file_utils import (
get_mtime_ns,
get_size_and_mtime_ns,
list_files_recursively,
verify_file_unchanged,
)
from app.assets.services.ingest import (
DependencyMissingError,
HashMismatchError,
create_from_hash,
upload_from_temp_path,
)
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
)
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
IngestResult,
ListAssetsResult,
ReferenceData,
RegisterAssetResult,
TagUsage,
UploadResult,
UserMetadata,
)
from app.assets.services.tagging import (
apply_tags,
list_tags,
remove_tags,
)
__all__ = [
"AddTagsResult",
"AssetData",
"AssetDetailResult",
"AssetSummaryData",
"ReferenceData",
"BulkInsertResult",
"DependencyMissingError",
"DownloadResolutionResult",
"HashMismatchError",
"IngestResult",
"ListAssetsResult",
"RegisterAssetResult",
"RemoveTagsResult",
"TagUsage",
"UploadResult",
"UserMetadata",
"apply_tags",
"asset_exists",
"batch_insert_seed_assets",
"create_from_hash",
"delete_asset_reference",
"get_asset_by_hash",
"get_asset_detail",
"get_mtime_ns",
"get_size_and_mtime_ns",
"list_assets_page",
"list_files_recursively",
"list_tags",
"cleanup_unreferenced_assets",
"remove_tags",
"resolve_asset_for_download",
"set_asset_preview",
"update_asset_metadata",
"upload_from_temp_path",
"verify_file_unchanged",
]

View File

@ -0,0 +1,309 @@
import contextlib
import mimetypes
import os
from typing import Sequence
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
reference_exists_for_asset_id,
delete_reference_by_id,
fetch_reference_and_asset,
soft_delete_reference_by_id,
fetch_reference_asset_and_tags,
get_asset_by_hash as queries_get_asset_by_hash,
get_reference_by_id,
get_reference_with_owner_check,
list_references_page,
list_references_by_asset_id,
set_reference_metadata,
set_reference_preview,
set_reference_tags,
update_reference_access_time,
update_reference_name,
update_reference_updated_at,
)
from app.assets.helpers import select_best_live_path
from app.assets.services.path_utils import compute_relative_filename
from app.assets.services.schemas import (
AssetData,
AssetDetailResult,
AssetSummaryData,
DownloadResolutionResult,
ListAssetsResult,
UserMetadata,
extract_asset_data,
extract_reference_data,
)
from app.database.db import create_session
def get_asset_detail(
reference_id: str,
owner_id: str = "",
) -> AssetDetailResult | None:
with create_session() as session:
result = fetch_reference_asset_and_tags(
session,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
return None
ref, asset, tags = result
return AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
def update_asset_metadata(
reference_id: str,
name: str | None = None,
tags: Sequence[str] | None = None,
user_metadata: UserMetadata = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id)
touched = False
if name is not None and name != ref.name:
update_reference_name(session, reference_id=reference_id, name=name)
touched = True
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
new_meta: dict | None = None
if user_metadata is not None:
new_meta = dict(user_metadata)
elif computed_filename:
current_meta = ref.user_metadata or {}
if current_meta.get("filename") != computed_filename:
new_meta = dict(current_meta)
if new_meta is not None:
if computed_filename:
new_meta["filename"] = computed_filename
set_reference_metadata(
session, reference_id=reference_id, user_metadata=new_meta
)
touched = True
if tags is not None:
set_reference_tags(
session,
reference_id=reference_id,
tags=tags,
origin=tag_origin,
)
touched = True
if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id)
result = fetch_reference_asset_and_tags(
session,
reference_id=reference_id,
owner_id=owner_id,
)
if not result:
raise RuntimeError("State changed during update")
ref, asset, tag_list = result
detail = AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_list,
)
session.commit()
return detail
def delete_asset_reference(
reference_id: str,
owner_id: str,
delete_content_if_orphan: bool = True,
) -> bool:
with create_session() as session:
if not delete_content_if_orphan:
# Soft delete: mark the reference as deleted but keep everything
deleted = soft_delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
session.commit()
return deleted
ref_row = get_reference_by_id(session, reference_id=reference_id)
asset_id = ref_row.asset_id if ref_row else None
file_path = ref_row.file_path if ref_row else None
deleted = delete_reference_by_id(
session, reference_id=reference_id, owner_id=owner_id
)
if not deleted:
session.commit()
return False
if not asset_id:
session.commit()
return True
still_exists = reference_exists_for_asset_id(session, asset_id=asset_id)
if still_exists:
session.commit()
return True
# Orphaned asset - delete it and its files
refs = list_references_by_asset_id(session, asset_id=asset_id)
file_paths = [
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
]
# Also include the just-deleted file path
if file_path:
file_paths.append(file_path)
asset_row = session.get(Asset, asset_id)
if asset_row is not None:
session.delete(asset_row)
session.commit()
# Delete files after commit
for p in file_paths:
with contextlib.suppress(Exception):
if p and os.path.isfile(p):
os.remove(p)
return True
def set_asset_preview(
reference_id: str,
preview_asset_id: str | None = None,
owner_id: str = "",
) -> AssetDetailResult:
with create_session() as session:
get_reference_with_owner_check(session, reference_id, owner_id)
set_reference_preview(
session,
reference_id=reference_id,
preview_asset_id=preview_asset_id,
)
result = fetch_reference_asset_and_tags(
session, reference_id=reference_id, owner_id=owner_id
)
if not result:
raise RuntimeError("State changed during preview update")
ref, asset, tags = result
detail = AssetDetailResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tags,
)
session.commit()
return detail
def asset_exists(asset_hash: str) -> bool:
with create_session() as session:
return asset_exists_by_hash(session, asset_hash=asset_hash)
def get_asset_by_hash(asset_hash: str) -> AssetData | None:
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash=asset_hash)
return extract_asset_data(asset)
def list_assets_page(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 20,
offset: int = 0,
sort: str = "created_at",
order: str = "desc",
) -> ListAssetsResult:
with create_session() as session:
refs, tag_map, total = list_references_page(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
offset=offset,
sort=sort,
order=order,
)
items: list[AssetSummaryData] = []
for ref in refs:
items.append(
AssetSummaryData(
ref=extract_reference_data(ref),
asset=extract_asset_data(ref.asset),
tags=tag_map.get(ref.id, []),
)
)
return ListAssetsResult(items=items, total=total)
def resolve_asset_for_download(
reference_id: str,
owner_id: str = "",
) -> DownloadResolutionResult:
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise ValueError(f"AssetReference {reference_id} not found")
ref, asset = pair
# For references with file_path, use that directly
if ref.file_path and os.path.isfile(ref.file_path):
abs_path = ref.file_path
else:
# For API-created refs without file_path, find a path from other refs
refs = list_references_by_asset_id(session, asset_id=asset.id)
abs_path = select_best_live_path(refs)
if not abs_path:
raise FileNotFoundError(
f"No live path for AssetReference {reference_id} "
f"(asset id={asset.id}, name={ref.name})"
)
# Capture ORM attributes before commit (commit expires loaded objects)
ref_name = ref.name
asset_mime = asset.mime_type
update_reference_access_time(session, reference_id=reference_id)
session.commit()
ctype = (
asset_mime
or mimetypes.guess_type(ref_name or abs_path)[0]
or "application/octet-stream"
)
download_name = ref_name or os.path.basename(abs_path)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=download_name,
)

View File

@ -0,0 +1,280 @@
from __future__ import annotations
import os
import uuid
from dataclasses import dataclass
from datetime import datetime
from typing import TYPE_CHECKING, Any, TypedDict
from sqlalchemy.orm import Session
from app.assets.database.queries import (
bulk_insert_assets,
bulk_insert_references_ignore_conflicts,
bulk_insert_tags_and_meta,
delete_assets_by_ids,
get_existing_asset_ids,
get_reference_ids_by_ids,
get_references_by_paths_and_asset_ids,
get_unreferenced_unhashed_asset_ids,
restore_references_by_paths,
)
from app.assets.helpers import get_utc_now
if TYPE_CHECKING:
from app.assets.services.metadata_extract import ExtractedMetadata
class SeedAssetSpec(TypedDict):
"""Spec for seeding an asset from filesystem."""
abs_path: str
size_bytes: int
mtime_ns: int
info_name: str
tags: list[str]
fname: str
metadata: ExtractedMetadata | None
hash: str | None
mime_type: str | None
class AssetRow(TypedDict):
"""Row data for inserting an Asset."""
id: str
hash: str | None
size_bytes: int
mime_type: str | None
created_at: datetime
class ReferenceRow(TypedDict):
"""Row data for inserting an AssetReference."""
id: str
asset_id: str
file_path: str
mtime_ns: int
owner_id: str
name: str
preview_id: str | None
user_metadata: dict[str, Any] | None
created_at: datetime
updated_at: datetime
last_access_time: datetime
class TagRow(TypedDict):
"""Row data for inserting a Tag."""
asset_reference_id: str
tag_name: str
origin: str
added_at: datetime
class MetadataRow(TypedDict):
"""Row data for inserting asset metadata."""
asset_reference_id: str
key: str
ordinal: int
val_str: str | None
val_num: float | None
val_bool: bool | None
val_json: dict[str, Any] | None
@dataclass
class BulkInsertResult:
"""Result of bulk asset insertion."""
inserted_refs: int
won_paths: int
lost_paths: int
def batch_insert_seed_assets(
session: Session,
specs: list[SeedAssetSpec],
owner_id: str = "",
) -> BulkInsertResult:
"""Seed assets from filesystem specs in batch.
Each spec is a dict with keys:
- abs_path: str
- size_bytes: int
- mtime_ns: int
- info_name: str
- tags: list[str]
- fname: Optional[str]
This function orchestrates:
1. Insert seed Assets (hash=NULL)
2. Claim references with ON CONFLICT DO NOTHING on file_path
3. Query to find winners (paths where our asset_id was inserted)
4. Delete Assets for losers (path already claimed by another asset)
5. Insert tags and metadata for successfully inserted references
Returns:
BulkInsertResult with inserted_refs, won_paths, lost_paths
"""
if not specs:
return BulkInsertResult(inserted_refs=0, won_paths=0, lost_paths=0)
current_time = get_utc_now()
asset_rows: list[AssetRow] = []
reference_rows: list[ReferenceRow] = []
path_to_asset_id: dict[str, str] = {}
asset_id_to_ref_data: dict[str, dict] = {}
absolute_path_list: list[str] = []
for spec in specs:
absolute_path = os.path.abspath(spec["abs_path"])
asset_id = str(uuid.uuid4())
reference_id = str(uuid.uuid4())
absolute_path_list.append(absolute_path)
path_to_asset_id[absolute_path] = asset_id
mime_type = spec.get("mime_type")
asset_rows.append(
{
"id": asset_id,
"hash": spec.get("hash"),
"size_bytes": spec["size_bytes"],
"mime_type": mime_type,
"created_at": current_time,
}
)
# Build user_metadata from extracted metadata or fallback to filename
extracted_metadata = spec.get("metadata")
if extracted_metadata:
user_metadata: dict[str, Any] | None = extracted_metadata.to_user_metadata()
elif spec["fname"]:
user_metadata = {"filename": spec["fname"]}
else:
user_metadata = None
reference_rows.append(
{
"id": reference_id,
"asset_id": asset_id,
"file_path": absolute_path,
"mtime_ns": spec["mtime_ns"],
"owner_id": owner_id,
"name": spec["info_name"],
"preview_id": None,
"user_metadata": user_metadata,
"created_at": current_time,
"updated_at": current_time,
"last_access_time": current_time,
}
)
asset_id_to_ref_data[asset_id] = {
"reference_id": reference_id,
"tags": spec["tags"],
"filename": spec["fname"],
"extracted_metadata": extracted_metadata,
}
bulk_insert_assets(session, asset_rows)
# Filter reference rows to only those whose assets were actually inserted
# (assets with duplicate hashes are silently dropped by ON CONFLICT DO NOTHING)
inserted_asset_ids = get_existing_asset_ids(
session, [r["asset_id"] for r in reference_rows]
)
reference_rows = [r for r in reference_rows if r["asset_id"] in inserted_asset_ids]
bulk_insert_references_ignore_conflicts(session, reference_rows)
restore_references_by_paths(session, absolute_path_list)
winning_paths = get_references_by_paths_and_asset_ids(session, path_to_asset_id)
inserted_paths = {
path
for path in absolute_path_list
if path_to_asset_id[path] in inserted_asset_ids
}
losing_paths = inserted_paths - winning_paths
lost_asset_ids = [path_to_asset_id[path] for path in losing_paths]
if lost_asset_ids:
delete_assets_by_ids(session, lost_asset_ids)
if not winning_paths:
return BulkInsertResult(
inserted_refs=0,
won_paths=0,
lost_paths=len(losing_paths),
)
# Get reference IDs for winners
winning_ref_ids = [
asset_id_to_ref_data[path_to_asset_id[path]]["reference_id"]
for path in winning_paths
]
inserted_ref_ids = get_reference_ids_by_ids(session, winning_ref_ids)
tag_rows: list[TagRow] = []
metadata_rows: list[MetadataRow] = []
if inserted_ref_ids:
for path in winning_paths:
asset_id = path_to_asset_id[path]
ref_data = asset_id_to_ref_data[asset_id]
ref_id = ref_data["reference_id"]
if ref_id not in inserted_ref_ids:
continue
for tag in ref_data["tags"]:
tag_rows.append(
{
"asset_reference_id": ref_id,
"tag_name": tag,
"origin": "automatic",
"added_at": current_time,
}
)
# Use extracted metadata for meta rows if available
extracted_metadata = ref_data.get("extracted_metadata")
if extracted_metadata:
metadata_rows.extend(extracted_metadata.to_meta_rows(ref_id))
elif ref_data["filename"]:
# Fallback: just store filename
metadata_rows.append(
{
"asset_reference_id": ref_id,
"key": "filename",
"ordinal": 0,
"val_str": ref_data["filename"],
"val_num": None,
"val_bool": None,
"val_json": None,
}
)
bulk_insert_tags_and_meta(session, tag_rows=tag_rows, meta_rows=metadata_rows)
return BulkInsertResult(
inserted_refs=len(inserted_ref_ids),
won_paths=len(winning_paths),
lost_paths=len(losing_paths),
)
def cleanup_unreferenced_assets(session: Session) -> int:
"""Hard-delete unhashed assets with no active references.
This is a destructive operation intended for explicit cleanup.
Only deletes assets where hash=None and all references are missing.
Returns:
Number of assets deleted
"""
unreferenced_ids = get_unreferenced_unhashed_asset_ids(session)
return delete_assets_by_ids(session, unreferenced_ids)

View File

@ -0,0 +1,70 @@
import os
def get_mtime_ns(stat_result: os.stat_result) -> int:
"""Extract mtime in nanoseconds from a stat result."""
return getattr(
stat_result, "st_mtime_ns", int(stat_result.st_mtime * 1_000_000_000)
)
def get_size_and_mtime_ns(path: str, follow_symlinks: bool = True) -> tuple[int, int]:
"""Get file size in bytes and mtime in nanoseconds."""
st = os.stat(path, follow_symlinks=follow_symlinks)
return st.st_size, get_mtime_ns(st)
def verify_file_unchanged(
mtime_db: int | None,
size_db: int | None,
stat_result: os.stat_result,
) -> bool:
"""Check if a file is unchanged based on mtime and size.
Returns True if the file's mtime and size match the database values.
Returns False if mtime_db is None or values don't match.
size_db=None means don't check size; 0 is a valid recorded size.
"""
if mtime_db is None:
return False
actual_mtime_ns = get_mtime_ns(stat_result)
if int(mtime_db) != int(actual_mtime_ns):
return False
if size_db is not None:
return int(stat_result.st_size) == int(size_db)
return True
def is_visible(name: str) -> bool:
"""Return True if a file or directory name is visible (not hidden)."""
return not name.startswith(".")
def list_files_recursively(base_dir: str) -> list[str]:
"""Recursively list all files in a directory, following symlinks."""
out: list[str] = []
base_abs = os.path.abspath(base_dir)
if not os.path.isdir(base_abs):
return out
# Track seen real directory identities to prevent circular symlink loops
seen_dirs: set[tuple[int, int]] = set()
for dirpath, subdirs, filenames in os.walk(
base_abs, topdown=True, followlinks=True
):
try:
st = os.stat(dirpath)
dir_id = (st.st_dev, st.st_ino)
except OSError:
subdirs.clear()
continue
if dir_id in seen_dirs:
subdirs.clear()
continue
seen_dirs.add(dir_id)
subdirs[:] = [d for d in subdirs if is_visible(d)]
for name in filenames:
if not is_visible(name):
continue
out.append(os.path.abspath(os.path.join(dirpath, name)))
return out

View File

@ -0,0 +1,99 @@
import io
import os
from contextlib import contextmanager
from dataclasses import dataclass
from typing import IO, Any, Callable, Iterator
import logging
try:
from blake3 import blake3
except ModuleNotFoundError:
logging.warning("WARNING: blake3 package not installed")
DEFAULT_CHUNK = 8 * 1024 * 1024
InterruptCheck = Callable[[], bool]
@dataclass
class HashCheckpoint:
"""Saved state for resuming an interrupted hash computation."""
bytes_processed: int
hasher: Any # blake3 hasher instance
mtime_ns: int = 0
file_size: int = 0
@contextmanager
def _open_for_hashing(fp: str | IO[bytes]) -> Iterator[tuple[IO[bytes], bool]]:
"""Yield (file_object, is_path) with appropriate setup/teardown."""
if hasattr(fp, "read"):
seekable = getattr(fp, "seekable", lambda: False)()
orig_pos = None
if seekable:
try:
orig_pos = fp.tell()
if orig_pos != 0:
fp.seek(0)
except io.UnsupportedOperation:
orig_pos = None
try:
yield fp, False
finally:
if orig_pos is not None:
fp.seek(orig_pos)
else:
with open(os.fspath(fp), "rb") as f:
yield f, True
def compute_blake3_hash(
fp: str | IO[bytes],
chunk_size: int = DEFAULT_CHUNK,
interrupt_check: InterruptCheck | None = None,
checkpoint: HashCheckpoint | None = None,
) -> tuple[str | None, HashCheckpoint | None]:
"""Compute BLAKE3 hash of a file, with optional checkpoint support.
Args:
fp: File path or file-like object
chunk_size: Size of chunks to read at a time
interrupt_check: Optional callable that returns True if the operation
should be interrupted (e.g. paused or cancelled). Must be
non-blocking so file handles are released immediately. Checked
between chunk reads.
checkpoint: Optional checkpoint to resume from (file paths only)
Returns:
Tuple of (hex_digest, None) on completion, or
(None, checkpoint) on interruption (file paths only), or
(None, None) on interruption of a file object
"""
if chunk_size <= 0:
chunk_size = DEFAULT_CHUNK
with _open_for_hashing(fp) as (f, is_path):
if checkpoint is not None and is_path:
f.seek(checkpoint.bytes_processed)
h = checkpoint.hasher
bytes_processed = checkpoint.bytes_processed
else:
h = blake3()
bytes_processed = 0
while True:
if interrupt_check is not None and interrupt_check():
if is_path:
return None, HashCheckpoint(
bytes_processed=bytes_processed,
hasher=h,
)
return None, None
chunk = f.read(chunk_size)
if not chunk:
break
h.update(chunk)
bytes_processed += len(chunk)
return h.hexdigest(), None

View File

@ -0,0 +1,375 @@
import contextlib
import logging
import mimetypes
import os
from typing import Any, Sequence
from sqlalchemy.orm import Session
import app.assets.services.hashing as hashing
from app.assets.database.queries import (
add_tags_to_reference,
fetch_reference_and_asset,
get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path,
get_reference_tags,
get_or_create_reference,
remove_missing_tag_for_asset_id,
set_reference_metadata,
set_reference_tags,
upsert_asset,
upsert_reference,
validate_tags_exist,
)
from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import (
compute_relative_filename,
resolve_destination_from_tags,
validate_path_within_base,
)
from app.assets.services.schemas import (
IngestResult,
RegisterAssetResult,
UploadResult,
UserMetadata,
extract_asset_data,
extract_reference_data,
)
from app.database.db import create_session
def _ingest_file_from_path(
abs_path: str,
asset_hash: str,
size_bytes: int,
mtime_ns: int,
mime_type: str | None = None,
info_name: str | None = None,
owner_id: str = "",
preview_id: str | None = None,
user_metadata: UserMetadata = None,
tags: Sequence[str] = (),
tag_origin: str = "manual",
require_existing_tags: bool = False,
) -> IngestResult:
locator = os.path.abspath(abs_path)
user_metadata = user_metadata or {}
asset_created = False
asset_updated = False
ref_created = False
ref_updated = False
reference_id: str | None = None
with create_session() as session:
if preview_id:
if preview_id not in get_existing_asset_ids(session, [preview_id]):
preview_id = None
asset, asset_created, asset_updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=size_bytes,
mime_type=mime_type,
)
ref_created, ref_updated = upsert_reference(
session,
asset_id=asset.id,
file_path=locator,
name=info_name or os.path.basename(locator),
mtime_ns=mtime_ns,
owner_id=owner_id,
)
# Get the reference we just created/updated
ref = get_reference_by_file_path(session, locator)
if ref:
reference_id = ref.id
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
norm = normalize_tags(list(tags))
if norm:
if require_existing_tags:
validate_tags_exist(session, norm)
add_tags_to_reference(
session,
reference_id=reference_id,
tags=norm,
origin=tag_origin,
create_if_missing=not require_existing_tags,
)
_update_metadata_with_filename(
session,
reference_id=reference_id,
file_path=ref.file_path,
current_metadata=ref.user_metadata,
user_metadata=user_metadata,
)
try:
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
except Exception:
logging.exception("Failed to clear 'missing' tag for asset %s", asset.id)
session.commit()
return IngestResult(
asset_created=asset_created,
asset_updated=asset_updated,
ref_created=ref_created,
ref_updated=ref_updated,
reference_id=reference_id,
)
def _register_existing_asset(
asset_hash: str,
name: str,
user_metadata: UserMetadata = None,
tags: list[str] | None = None,
tag_origin: str = "manual",
owner_id: str = "",
) -> RegisterAssetResult:
user_metadata = user_metadata or {}
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=asset_hash)
if not asset:
raise ValueError(f"No asset with hash {asset_hash}")
ref, ref_created = get_or_create_reference(
session,
asset_id=asset.id,
owner_id=owner_id,
name=name,
)
if not ref_created:
tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=False,
)
session.commit()
return result
new_meta = dict(user_metadata)
computed_filename = compute_relative_filename(ref.file_path) if ref.file_path else None
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta:
set_reference_metadata(
session,
reference_id=ref.id,
user_metadata=new_meta,
)
if tags is not None:
set_reference_tags(
session,
reference_id=ref.id,
tags=tags,
origin=tag_origin,
)
tag_names = get_reference_tags(session, reference_id=ref.id)
session.refresh(ref)
result = RegisterAssetResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created=True,
)
session.commit()
return result
def _update_metadata_with_filename(
session: Session,
reference_id: str,
file_path: str | None,
current_metadata: dict | None,
user_metadata: dict[str, Any],
) -> None:
computed_filename = compute_relative_filename(file_path) if file_path else None
current_meta = current_metadata or {}
new_meta = dict(current_meta)
for k, v in user_metadata.items():
new_meta[k] = v
if computed_filename:
new_meta["filename"] = computed_filename
if new_meta != current_meta:
set_reference_metadata(
session,
reference_id=reference_id,
user_metadata=new_meta,
)
def _sanitize_filename(name: str | None, fallback: str) -> str:
n = os.path.basename((name or "").strip() or fallback)
return n if n else fallback
class HashMismatchError(Exception):
pass
class DependencyMissingError(Exception):
def __init__(self, message: str):
self.message = message
super().__init__(message)
def upload_from_temp_path(
temp_path: str,
name: str | None = None,
tags: list[str] | None = None,
user_metadata: dict | None = None,
client_filename: str | None = None,
owner_id: str = "",
expected_hash: str | None = None,
) -> UploadResult:
try:
digest, _ = hashing.compute_blake3_hash(temp_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash uploaded file: {e}")
asset_hash = "blake3:" + digest
if expected_hash and asset_hash != expected_hash.strip().lower():
raise HashMismatchError("Uploaded file hash does not match provided hash.")
with create_session() as session:
existing = get_asset_by_hash(session, asset_hash=asset_hash)
if existing is not None:
with contextlib.suppress(Exception):
if temp_path and os.path.exists(temp_path):
os.remove(temp_path)
display_name = _sanitize_filename(name or client_filename, fallback=digest)
result = _register_existing_asset(
asset_hash=asset_hash,
name=display_name,
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
)
if not tags:
raise ValueError("tags are required for new asset uploads")
base_dir, subdirs = resolve_destination_from_tags(tags)
dest_dir = os.path.join(base_dir, *subdirs) if subdirs else base_dir
os.makedirs(dest_dir, exist_ok=True)
src_for_ext = (client_filename or name or "").strip()
_ext = os.path.splitext(os.path.basename(src_for_ext))[1] if src_for_ext else ""
ext = _ext if 0 < len(_ext) <= 16 else ""
hashed_basename = f"{digest}{ext}"
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir)
content_type = (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream"
)
try:
os.replace(temp_path, dest_abs)
except Exception as e:
raise RuntimeError(f"failed to move uploaded file into place: {e}")
try:
size_bytes, mtime_ns = get_size_and_mtime_ns(dest_abs)
except OSError as e:
raise RuntimeError(f"failed to stat destination file: {e}")
ingest_result = _ingest_file_from_path(
asset_hash=asset_hash,
abs_path=dest_abs,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id,
preview_id=None,
user_metadata=user_metadata or {},
tags=tags,
tag_origin="manual",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash(
hash_str: str,
name: str,
tags: list[str] | None = None,
user_metadata: dict | None = None,
owner_id: str = "",
) -> UploadResult | None:
canonical = hash_str.strip().lower()
with create_session() as session:
asset = get_asset_by_hash(session, asset_hash=canonical)
if not asset:
return None
result = _register_existing_asset(
asset_hash=canonical,
name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
),
user_metadata=user_metadata or {},
tags=tags or [],
tag_origin="manual",
owner_id=owner_id,
)
return UploadResult(
ref=result.ref,
asset=result.asset,
tags=result.tags,
created_new=False,
)

View File

@ -0,0 +1,327 @@
"""Metadata extraction for asset scanning.
Tier 1: Filesystem metadata (zero parsing)
Tier 2: Safetensors header metadata (fast JSON read only)
"""
from __future__ import annotations
import json
import logging
import mimetypes
import os
import struct
from dataclasses import dataclass
from typing import Any
from utils.mime_types import init_mime_types
init_mime_types()
# Supported safetensors extensions
SAFETENSORS_EXTENSIONS = frozenset({".safetensors", ".sft"})
# Maximum safetensors header size to read (8MB)
MAX_SAFETENSORS_HEADER_SIZE = 8 * 1024 * 1024
@dataclass
class ExtractedMetadata:
"""Metadata extracted from a file during scanning."""
# Tier 1: Filesystem (always available)
filename: str = ""
file_path: str = "" # Full absolute path to the file
content_length: int = 0
content_type: str | None = None
format: str = "" # file extension without dot
# Tier 2: Safetensors header (if available)
base_model: str | None = None
trained_words: list[str] | None = None
air: str | None = None # CivitAI AIR identifier
has_preview_images: bool = False
# Source provenance (populated if embedded in safetensors)
source_url: str | None = None
source_arn: str | None = None
repo_url: str | None = None
preview_url: str | None = None
source_hash: str | None = None
# HuggingFace specific
repo_id: str | None = None
revision: str | None = None
filepath: str | None = None
resolve_url: str | None = None
def to_user_metadata(self) -> dict[str, Any]:
"""Convert to user_metadata dict for AssetReference.user_metadata JSON field."""
data: dict[str, Any] = {
"filename": self.filename,
"content_length": self.content_length,
"format": self.format,
}
if self.file_path:
data["file_path"] = self.file_path
if self.content_type:
data["content_type"] = self.content_type
# Tier 2 fields
if self.base_model:
data["base_model"] = self.base_model
if self.trained_words:
data["trained_words"] = self.trained_words
if self.air:
data["air"] = self.air
if self.has_preview_images:
data["has_preview_images"] = True
# Source provenance
if self.source_url:
data["source_url"] = self.source_url
if self.source_arn:
data["source_arn"] = self.source_arn
if self.repo_url:
data["repo_url"] = self.repo_url
if self.preview_url:
data["preview_url"] = self.preview_url
if self.source_hash:
data["source_hash"] = self.source_hash
# HuggingFace
if self.repo_id:
data["repo_id"] = self.repo_id
if self.revision:
data["revision"] = self.revision
if self.filepath:
data["filepath"] = self.filepath
if self.resolve_url:
data["resolve_url"] = self.resolve_url
return data
def to_meta_rows(self, reference_id: str) -> list[dict]:
"""Convert to asset_reference_meta rows for typed/indexed querying."""
rows: list[dict] = []
def add_str(key: str, val: str | None, ordinal: int = 0) -> None:
if val:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": ordinal,
"val_str": val[:2048] if len(val) > 2048 else val,
"val_num": None,
"val_bool": None,
"val_json": None,
})
def add_num(key: str, val: int | float | None) -> None:
if val is not None:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": val,
"val_bool": None,
"val_json": None,
})
def add_bool(key: str, val: bool | None) -> None:
if val is not None:
rows.append({
"asset_reference_id": reference_id,
"key": key,
"ordinal": 0,
"val_str": None,
"val_num": None,
"val_bool": val,
"val_json": None,
})
# Tier 1
add_str("filename", self.filename)
add_num("content_length", self.content_length)
add_str("content_type", self.content_type)
add_str("format", self.format)
# Tier 2
add_str("base_model", self.base_model)
add_str("air", self.air)
has_previews = self.has_preview_images if self.has_preview_images else None
add_bool("has_preview_images", has_previews)
# trained_words as multiple rows with ordinals
if self.trained_words:
for i, word in enumerate(self.trained_words[:100]): # limit to 100 words
add_str("trained_words", word, ordinal=i)
# Source provenance
add_str("source_url", self.source_url)
add_str("source_arn", self.source_arn)
add_str("repo_url", self.repo_url)
add_str("preview_url", self.preview_url)
add_str("source_hash", self.source_hash)
# HuggingFace
add_str("repo_id", self.repo_id)
add_str("revision", self.revision)
add_str("filepath", self.filepath)
add_str("resolve_url", self.resolve_url)
return rows
def _read_safetensors_header(
path: str, max_size: int = MAX_SAFETENSORS_HEADER_SIZE
) -> dict[str, Any] | None:
"""Read only the JSON header from a safetensors file.
This is very fast - reads 8 bytes for header length, then the JSON header.
No tensor data is loaded.
Args:
path: Absolute path to safetensors file
max_size: Maximum header size to read (default 8MB)
Returns:
Parsed header dict or None if failed
"""
try:
with open(path, "rb") as f:
header_bytes = f.read(8)
if len(header_bytes) < 8:
return None
length_of_header = struct.unpack("<Q", header_bytes)[0]
if length_of_header > max_size:
return None
header_data = f.read(length_of_header)
if len(header_data) < length_of_header:
return None
return json.loads(header_data.decode("utf-8"))
except (OSError, json.JSONDecodeError, UnicodeDecodeError, struct.error):
return None
def _extract_safetensors_metadata(
header: dict[str, Any], meta: ExtractedMetadata
) -> None:
"""Extract metadata from safetensors header __metadata__ section.
Modifies meta in-place.
"""
st_meta = header.get("__metadata__", {})
if not isinstance(st_meta, dict):
return
# Common model metadata
meta.base_model = (
st_meta.get("ss_base_model_version")
or st_meta.get("modelspec.base_model")
or st_meta.get("base_model")
)
# Trained words / trigger words
trained_words = st_meta.get("ss_tag_frequency")
if trained_words and isinstance(trained_words, str):
try:
tag_freq = json.loads(trained_words)
# Extract unique tags from all datasets
all_tags: set[str] = set()
for dataset_tags in tag_freq.values():
if isinstance(dataset_tags, dict):
all_tags.update(dataset_tags.keys())
if all_tags:
meta.trained_words = sorted(all_tags)[:100]
except json.JSONDecodeError:
pass
# Direct trained_words field (some formats)
if not meta.trained_words:
tw = st_meta.get("trained_words")
if isinstance(tw, str):
try:
parsed = json.loads(tw)
if isinstance(parsed, list):
meta.trained_words = [str(x) for x in parsed]
else:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
except json.JSONDecodeError:
meta.trained_words = [w.strip() for w in tw.split(",") if w.strip()]
elif isinstance(tw, list):
meta.trained_words = [str(x) for x in tw]
# CivitAI AIR
meta.air = st_meta.get("air") or st_meta.get("modelspec.air")
# Preview images (ssmd_cover_images)
cover_images = st_meta.get("ssmd_cover_images")
if cover_images:
meta.has_preview_images = True
# Source provenance fields
meta.source_url = st_meta.get("source_url")
meta.source_arn = st_meta.get("source_arn")
meta.repo_url = st_meta.get("repo_url")
meta.preview_url = st_meta.get("preview_url")
meta.source_hash = st_meta.get("source_hash") or st_meta.get("sshs_model_hash")
# HuggingFace fields
meta.repo_id = st_meta.get("repo_id") or st_meta.get("hf_repo_id")
meta.revision = st_meta.get("revision") or st_meta.get("hf_revision")
meta.filepath = st_meta.get("filepath") or st_meta.get("hf_filepath")
meta.resolve_url = st_meta.get("resolve_url") or st_meta.get("hf_url")
def extract_file_metadata(
abs_path: str,
stat_result: os.stat_result | None = None,
relative_filename: str | None = None,
) -> ExtractedMetadata:
"""Extract metadata from a file using tier 1 and tier 2 methods.
Tier 1: Filesystem metadata from path and stat
Tier 2: Safetensors header parsing if applicable
Args:
abs_path: Absolute path to the file
stat_result: Optional pre-fetched stat result (saves a syscall)
relative_filename: Optional relative filename to use instead of basename
(e.g., "flux/123/model.safetensors" for model paths)
Returns:
ExtractedMetadata with all available fields populated
"""
meta = ExtractedMetadata()
# Tier 1: Filesystem metadata
meta.filename = relative_filename or os.path.basename(abs_path)
meta.file_path = abs_path
_, ext = os.path.splitext(abs_path)
meta.format = ext.lstrip(".").lower() if ext else ""
mime_type, _ = mimetypes.guess_type(abs_path)
meta.content_type = mime_type
# Size from stat
if stat_result is None:
try:
stat_result = os.stat(abs_path, follow_symlinks=True)
except OSError:
pass
if stat_result:
meta.content_length = stat_result.st_size
# Tier 2: Safetensors header (if applicable and enabled)
if ext.lower() in SAFETENSORS_EXTENSIONS:
header = _read_safetensors_header(abs_path)
if header:
try:
_extract_safetensors_metadata(header, meta)
except Exception as e:
logging.debug("Safetensors meta extract failed %s: %s", abs_path, e)
return meta

View File

@ -0,0 +1,167 @@
import os
from pathlib import Path
from typing import Literal
import folder_paths
from app.assets.helpers import normalize_tags
_NON_MODEL_FOLDER_NAMES = frozenset({"custom_nodes"})
def get_comfy_models_folders() -> list[tuple[str, list[str]]]:
"""Build list of (folder_name, base_paths[]) for all model locations.
Includes every category registered in folder_names_and_paths,
regardless of whether its paths are under the main models_dir,
but excludes non-model entries like custom_nodes.
"""
targets: list[tuple[str, list[str]]] = []
for name, values in folder_paths.folder_names_and_paths.items():
if name in _NON_MODEL_FOLDER_NAMES:
continue
paths, _exts = values[0], values[1]
if paths:
targets.append((name, paths))
return targets
def resolve_destination_from_tags(tags: list[str]) -> tuple[str, list[str]]:
"""Validates and maps tags -> (base_dir, subdirs_for_fs)"""
if not tags:
raise ValueError("tags must not be empty")
root = tags[0].lower()
if root == "models":
if len(tags) < 2:
raise ValueError("at least two tags required for model asset")
try:
bases = folder_paths.folder_names_and_paths[tags[1]][0]
except KeyError:
raise ValueError(f"unknown model category '{tags[1]}'")
if not bases:
raise ValueError(f"no base path configured for category '{tags[1]}'")
base_dir = os.path.abspath(bases[0])
raw_subdirs = tags[2:]
elif root == "input":
base_dir = os.path.abspath(folder_paths.get_input_directory())
raw_subdirs = tags[1:]
elif root == "output":
base_dir = os.path.abspath(folder_paths.get_output_directory())
raw_subdirs = tags[1:]
else:
raise ValueError(f"unknown root tag '{tags[0]}'; expected 'models', 'input', or 'output'")
_sep_chars = frozenset(("/", "\\", os.sep))
for i in raw_subdirs:
if i in (".", "..") or _sep_chars & set(i):
raise ValueError("invalid path component in tags")
return base_dir, raw_subdirs if raw_subdirs else []
def validate_path_within_base(candidate: str, base: str) -> None:
cand_abs = Path(os.path.abspath(candidate))
base_abs = Path(os.path.abspath(base))
if not cand_abs.is_relative_to(base_abs):
raise ValueError("destination escapes base directory")
def compute_relative_filename(file_path: str) -> str | None:
"""
Return the model's path relative to the last well-known folder (the model category),
using forward slashes, eg:
/.../models/checkpoints/flux/123/flux.safetensors -> "flux/123/flux.safetensors"
/.../models/text_encoders/clip_g.safetensors -> "clip_g.safetensors"
For non-model paths, returns None.
"""
try:
root_category, rel_path = get_asset_category_and_relative_path(file_path)
except ValueError:
return None
p = Path(rel_path)
parts = [seg for seg in p.parts if seg not in (".", "..", p.anchor)]
if not parts:
return None
if root_category == "models":
# parts[0] is the category ("checkpoints", "vae", etc) drop it
inside = parts[1:] if len(parts) > 1 else [parts[0]]
return "/".join(inside)
return "/".join(parts) # input/output: keep all parts
def get_asset_category_and_relative_path(
file_path: str,
) -> tuple[Literal["input", "output", "models"], str]:
"""Determine which root category a file path belongs to.
Categories:
- 'input': under folder_paths.get_input_directory()
- 'output': under folder_paths.get_output_directory()
- 'models': under any base path from get_comfy_models_folders()
Returns:
(root_category, relative_path_inside_that_root)
Raises:
ValueError: path does not belong to any known root.
"""
fp_abs = os.path.abspath(file_path)
def _check_is_within(child: str, parent: str) -> bool:
return Path(child).is_relative_to(parent)
def _compute_relative(child: str, parent: str) -> str:
# Normalize relative path, stripping any leading ".." components
# by anchoring to root (os.sep) then computing relpath back from it.
return os.path.relpath(
os.path.join(os.sep, os.path.relpath(child, parent)), os.sep
)
# 1) input
input_base = os.path.abspath(folder_paths.get_input_directory())
if _check_is_within(fp_abs, input_base):
return "input", _compute_relative(fp_abs, input_base)
# 2) output
output_base = os.path.abspath(folder_paths.get_output_directory())
if _check_is_within(fp_abs, output_base):
return "output", _compute_relative(fp_abs, output_base)
# 3) models (check deepest matching base to avoid ambiguity)
best: tuple[int, str, str] | None = None # (base_len, bucket, rel_inside_bucket)
for bucket, bases in get_comfy_models_folders():
for b in bases:
base_abs = os.path.abspath(b)
if not _check_is_within(fp_abs, base_abs):
continue
cand = (len(base_abs), bucket, _compute_relative(fp_abs, base_abs))
if best is None or cand[0] > best[0]:
best = cand
if best is not None:
_, bucket, rel_inside = best
combined = os.path.join(bucket, rel_inside)
return "models", os.path.relpath(os.path.join(os.sep, combined), os.sep)
raise ValueError(
f"Path is not within input, output, or configured model bases: {file_path}"
)
def get_name_and_tags_from_asset_path(file_path: str) -> tuple[str, list[str]]:
"""Return (name, tags) derived from a filesystem path.
- name: base filename with extension
- tags: [root_category] + parent folder names in order
Raises:
ValueError: path does not belong to any known root.
"""
root_category, some_path = get_asset_category_and_relative_path(file_path)
p = Path(some_path)
parent_parts = [
part for part in p.parent.parts if part not in (".", "..", p.anchor)
]
return p.name, list(dict.fromkeys(normalize_tags([root_category, *parent_parts])))

View File

@ -0,0 +1,109 @@
from dataclasses import dataclass
from datetime import datetime
from typing import Any, NamedTuple
from app.assets.database.models import Asset, AssetReference
UserMetadata = dict[str, Any] | None
@dataclass(frozen=True)
class AssetData:
hash: str | None
size_bytes: int | None
mime_type: str | None
@dataclass(frozen=True)
class ReferenceData:
"""Data transfer object for AssetReference."""
id: str
name: str
file_path: str | None
user_metadata: UserMetadata
preview_id: str | None
created_at: datetime
updated_at: datetime
last_access_time: datetime | None
@dataclass(frozen=True)
class AssetDetailResult:
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class RegisterAssetResult:
ref: ReferenceData
asset: AssetData
tags: list[str]
created: bool
@dataclass(frozen=True)
class IngestResult:
asset_created: bool
asset_updated: bool
ref_created: bool
ref_updated: bool
reference_id: str | None
class TagUsage(NamedTuple):
name: str
tag_type: str
count: int
@dataclass(frozen=True)
class AssetSummaryData:
ref: ReferenceData
asset: AssetData | None
tags: list[str]
@dataclass(frozen=True)
class ListAssetsResult:
items: list[AssetSummaryData]
total: int
@dataclass(frozen=True)
class DownloadResolutionResult:
abs_path: str
content_type: str
download_name: str
@dataclass(frozen=True)
class UploadResult:
ref: ReferenceData
asset: AssetData
tags: list[str]
created_new: bool
def extract_reference_data(ref: AssetReference) -> ReferenceData:
return ReferenceData(
id=ref.id,
name=ref.name,
file_path=ref.file_path,
user_metadata=ref.user_metadata,
preview_id=ref.preview_id,
created_at=ref.created_at,
updated_at=ref.updated_at,
last_access_time=ref.last_access_time,
)
def extract_asset_data(asset: Asset | None) -> AssetData | None:
if asset is None:
return None
return AssetData(
hash=asset.hash,
size_bytes=asset.size_bytes,
mime_type=asset.mime_type,
)

View File

@ -0,0 +1,75 @@
from app.assets.database.queries import (
AddTagsResult,
RemoveTagsResult,
add_tags_to_reference,
get_reference_with_owner_check,
list_tags_with_usage,
remove_tags_from_reference,
)
from app.assets.services.schemas import TagUsage
from app.database.db import create_session
def apply_tags(
reference_id: str,
tags: list[str],
origin: str = "manual",
owner_id: str = "",
) -> AddTagsResult:
with create_session() as session:
ref_row = get_reference_with_owner_check(session, reference_id, owner_id)
result = add_tags_to_reference(
session,
reference_id=reference_id,
tags=tags,
origin=origin,
create_if_missing=True,
reference_row=ref_row,
)
session.commit()
return result
def remove_tags(
reference_id: str,
tags: list[str],
owner_id: str = "",
) -> RemoveTagsResult:
with create_session() as session:
get_reference_with_owner_check(session, reference_id, owner_id)
result = remove_tags_from_reference(
session,
reference_id=reference_id,
tags=tags,
)
session.commit()
return result
def list_tags(
prefix: str | None = None,
limit: int = 100,
offset: int = 0,
order: str = "count_desc",
include_zero: bool = True,
owner_id: str = "",
) -> tuple[list[TagUsage], int]:
limit = max(1, min(1000, limit))
offset = max(0, offset)
with create_session() as session:
rows, total = list_tags_with_usage(
session,
prefix=prefix,
limit=limit,
offset=offset,
include_zero=include_zero,
order=order,
owner_id=owner_id,
)
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total

View File

@ -3,6 +3,7 @@ import os
import shutil
from app.logger import log_startup_warning
from utils.install_util import get_missing_requirements_message
from filelock import FileLock, Timeout
from comfy.cli_args import args
_DB_AVAILABLE = False
@ -14,8 +15,12 @@ try:
from alembic.config import Config
from alembic.runtime.migration import MigrationContext
from alembic.script import ScriptDirectory
from sqlalchemy import create_engine
from sqlalchemy import create_engine, event
from sqlalchemy.orm import sessionmaker
from sqlalchemy.pool import StaticPool
from app.database.models import Base
import app.assets.database.models # noqa: F401 — register models with Base.metadata
_DB_AVAILABLE = True
except ImportError as e:
@ -65,9 +70,69 @@ def get_db_path():
raise ValueError(f"Unsupported database URL '{url}'.")
_db_lock = None
def _acquire_file_lock(db_path):
"""Acquire an OS-level file lock to prevent multi-process access.
Uses filelock for cross-platform support (macOS, Linux, Windows).
The OS automatically releases the lock when the process exits, even on crashes.
"""
global _db_lock
lock_path = db_path + ".lock"
_db_lock = FileLock(lock_path)
try:
_db_lock.acquire(timeout=0)
except Timeout:
raise RuntimeError(
f"Could not acquire lock on database '{db_path}'. "
"Another ComfyUI process may already be using it. "
"Use --database-url to specify a separate database file."
)
def _is_memory_db(db_url):
"""Check if the database URL refers to an in-memory SQLite database."""
return db_url in ("sqlite:///:memory:", "sqlite://")
def init_db():
db_url = args.database_url
logging.debug(f"Database URL: {db_url}")
if _is_memory_db(db_url):
_init_memory_db(db_url)
else:
_init_file_db(db_url)
def _init_memory_db(db_url):
"""Initialize an in-memory SQLite database using metadata.create_all.
Alembic migrations don't work with in-memory SQLite because each
connection gets its own separate database tables created by Alembic's
internal connection are lost immediately.
"""
engine = create_engine(
db_url,
poolclass=StaticPool,
connect_args={"check_same_thread": False},
)
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
Base.metadata.create_all(engine)
global Session
Session = sessionmaker(bind=engine)
def _init_file_db(db_url):
"""Initialize a file-backed SQLite database using Alembic migrations."""
db_path = get_db_path()
db_exists = os.path.exists(db_path)
@ -75,6 +140,14 @@ def init_db():
# Check if we need to upgrade
engine = create_engine(db_url)
# Enable foreign key enforcement for SQLite
@event.listens_for(engine, "connect")
def set_sqlite_pragma(dbapi_connection, connection_record):
cursor = dbapi_connection.cursor()
cursor.execute("PRAGMA foreign_keys=ON")
cursor.close()
conn = engine.connect()
context = MigrationContext.configure(conn)
@ -104,6 +177,12 @@ def init_db():
logging.exception("Error upgrading database: ")
raise e
# Acquire an OS-level file lock after migrations are complete.
# Alembic uses its own connection, so we must wait until it's done
# before locking — otherwise our own lock blocks the migration.
conn.close()
_acquire_file_lock(db_path)
global Session
Session = sessionmaker(bind=engine)

View File

@ -27,6 +27,7 @@ class AudioEncoderModel():
self.model.eval()
self.patcher = comfy.model_patcher.CoreModelPatcher(self.model, load_device=self.load_device, offload_device=offload_device)
self.model_sample_rate = 16000
comfy.model_management.archive_model_dtypes(self.model)
def load_sd(self, sd):
return self.model.load_state_dict(sd, strict=False, assign=self.patcher.is_dynamic())

View File

@ -234,7 +234,7 @@ database_default_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "user", "comfyui.db")
)
parser.add_argument("--database-url", type=str, default=f"sqlite:///{database_default_path}", help="Specify the database URL, e.g. for an in-memory database you can use 'sqlite:///:memory:'.")
parser.add_argument("--disable-assets-autoscan", action="store_true", help="Disable asset scanning on startup for database synchronization.")
parser.add_argument("--enable-assets", action="store_true", help="Enable the assets system (API routes, database synchronization, and background scanning).")
if comfy.options.args_parsing:
args = parser.parse_args()

View File

@ -395,6 +395,105 @@ class ComfyUIAdapter(IsolationAdapter):
registry.register("ndarray", serialize_numpy, None)
def serialize_ply(obj: Any) -> Dict[str, Any]:
import base64
import torch
if obj.raw_data is not None:
return {
"__type__": "PLY",
"raw_data": base64.b64encode(obj.raw_data).decode("ascii"),
}
result: Dict[str, Any] = {"__type__": "PLY", "points": torch.from_numpy(obj.points)}
if obj.colors is not None:
result["colors"] = torch.from_numpy(obj.colors)
if obj.confidence is not None:
result["confidence"] = torch.from_numpy(obj.confidence)
if obj.view_id is not None:
result["view_id"] = torch.from_numpy(obj.view_id)
return result
def deserialize_ply(data: Any) -> Any:
import base64
from comfy_api.latest._util.ply_types import PLY
if "raw_data" in data:
return PLY(raw_data=base64.b64decode(data["raw_data"]))
return PLY(
points=data["points"],
colors=data.get("colors"),
confidence=data.get("confidence"),
view_id=data.get("view_id"),
)
registry.register("PLY", serialize_ply, deserialize_ply, data_type=True)
def serialize_npz(obj: Any) -> Dict[str, Any]:
import base64
return {
"__type__": "NPZ",
"frames": [base64.b64encode(f).decode("ascii") for f in obj.frames],
}
def deserialize_npz(data: Any) -> Any:
import base64
from comfy_api.latest._util.npz_types import NPZ
return NPZ(frames=[base64.b64decode(f) for f in data["frames"]])
registry.register("NPZ", serialize_npz, deserialize_npz, data_type=True)
def serialize_file3d(obj: Any) -> Dict[str, Any]:
import base64
return {
"__type__": "File3D",
"format": obj.format,
"data": base64.b64encode(obj.get_bytes()).decode("ascii"),
}
def deserialize_file3d(data: Any) -> Any:
import base64
from io import BytesIO
from comfy_api.latest._util.geometry_types import File3D
return File3D(BytesIO(base64.b64decode(data["data"])), file_format=data["format"])
registry.register("File3D", serialize_file3d, deserialize_file3d, data_type=True)
def serialize_video(obj: Any) -> Dict[str, Any]:
components = obj.get_components()
images = components.images.detach() if components.images.requires_grad else components.images
result: Dict[str, Any] = {
"__type__": "VIDEO",
"images": images,
"frame_rate_num": components.frame_rate.numerator,
"frame_rate_den": components.frame_rate.denominator,
}
if components.audio is not None:
waveform = components.audio["waveform"]
if waveform.requires_grad:
waveform = waveform.detach()
result["audio_waveform"] = waveform
result["audio_sample_rate"] = components.audio["sample_rate"]
if components.metadata is not None:
result["metadata"] = components.metadata
return result
def deserialize_video(data: Any) -> Any:
from fractions import Fraction
from comfy_api.latest._input_impl.video_types import VideoFromComponents
from comfy_api.latest._util.video_types import VideoComponents
audio = None
if "audio_waveform" in data:
audio = {"waveform": data["audio_waveform"], "sample_rate": data["audio_sample_rate"]}
components = VideoComponents(
images=data["images"],
frame_rate=Fraction(data["frame_rate_num"], data["frame_rate_den"]),
audio=audio,
metadata=data.get("metadata"),
)
return VideoFromComponents(components)
registry.register("VIDEO", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromFile", serialize_video, deserialize_video, data_type=True)
registry.register("VideoFromComponents", serialize_video, deserialize_video, data_type=True)
def provide_rpc_services(self) -> List[type[ProxiedSingleton]]:
return [
PromptServerService,
@ -423,6 +522,13 @@ class ComfyUIAdapter(IsolationAdapter):
for name in dir(instance):
if not name.startswith("_"):
setattr(folder_paths, name, getattr(instance, name))
# Fence: isolated children get writable temp inside sandbox
if os.environ.get("PYISOLATE_CHILD") == "1":
_child_temp = os.path.join("/tmp", "comfyui_temp")
os.makedirs(_child_temp, exist_ok=True)
folder_paths.temp_directory = _child_temp
return
if api_name == "ModelManagementProxy":

View File

@ -12,6 +12,8 @@ from typing import Callable, Dict, List, Tuple
import pyisolate
from pyisolate import ExtensionManager, ExtensionManagerConfig
from packaging.requirements import InvalidRequirement, Requirement
from packaging.utils import canonicalize_name
from .extension_wrapper import ComfyNodeExtension
from .manifest_loader import is_cache_valid, load_from_cache, save_to_cache
@ -63,6 +65,94 @@ def _normalize_dependency_spec(dep: str, base_paths: list[Path]) -> str:
return dep
def _dependency_name_from_spec(dep: str) -> str | None:
stripped = dep.strip()
if not stripped or stripped == "-e" or stripped.startswith("-e "):
return None
if stripped.startswith(("/", "./", "../", "file://")):
return None
try:
return canonicalize_name(Requirement(stripped).name)
except InvalidRequirement:
return None
def _parse_cuda_wheels_config(
tool_config: dict[str, object], dependencies: list[str]
) -> dict[str, object] | None:
raw_config = tool_config.get("cuda_wheels")
if raw_config is None:
return None
if not isinstance(raw_config, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels] must be a table"
)
index_url = raw_config.get("index_url")
if not isinstance(index_url, str) or not index_url.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.index_url] must be a non-empty string"
)
packages = raw_config.get("packages")
if not isinstance(packages, list) or not all(
isinstance(package_name, str) and package_name.strip()
for package_name in packages
):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] must be a list of non-empty strings"
)
declared_dependencies = {
dependency_name
for dep in dependencies
if (dependency_name := _dependency_name_from_spec(dep)) is not None
}
normalized_packages = [canonicalize_name(package_name) for package_name in packages]
missing = [
package_name
for package_name in normalized_packages
if package_name not in declared_dependencies
]
if missing:
missing_joined = ", ".join(sorted(missing))
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.packages] references undeclared dependencies: "
f"{missing_joined}"
)
package_map = raw_config.get("package_map", {})
if not isinstance(package_map, dict):
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] must be a table"
)
normalized_package_map: dict[str, str] = {}
for dependency_name, index_package_name in package_map.items():
if not isinstance(dependency_name, str) or not dependency_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] keys must be non-empty strings"
)
if not isinstance(index_package_name, str) or not index_package_name.strip():
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] values must be non-empty strings"
)
canonical_dependency_name = canonicalize_name(dependency_name)
if canonical_dependency_name not in normalized_packages:
raise ExtensionLoadError(
"[tool.comfy.isolation.cuda_wheels.package_map] can only override packages listed in "
"[tool.comfy.isolation.cuda_wheels.packages]"
)
normalized_package_map[canonical_dependency_name] = index_package_name.strip()
return {
"index_url": index_url.rstrip("/") + "/",
"packages": normalized_packages,
"package_map": normalized_package_map,
}
def get_enforcement_policy() -> Dict[str, bool]:
return {
"force_isolated": os.environ.get("PYISOLATE_ENFORCE_ISOLATED") == "1",
@ -138,6 +228,7 @@ async def load_isolated_node(
_normalize_dependency_spec(dep, base_paths) if isinstance(dep, str) else dep
for dep in dependencies
]
cuda_wheels = _parse_cuda_wheels_config(tool_config, dependencies)
manager_config = ExtensionManagerConfig(venv_root_path=str(venv_root))
manager: ExtensionManager = pyisolate.ExtensionManager(
@ -166,6 +257,8 @@ async def load_isolated_node(
"share_cuda_ipc": share_cuda_ipc,
"sandbox": sandbox_config,
}
if cuda_wheels is not None:
extension_config["cuda_wheels"] = cuda_wheels
extension = manager.load_extension(extension_config)
register_dummy_module(extension_name, node_dir)

View File

@ -306,7 +306,7 @@ class ComfyNodeExtension(ExtensionBase):
node_name,
len(inputs),
)
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
if os.environ.get("PYISOLATE_CHILD") == "1":
_relieve_child_vram_pressure("EXT:pre_execute")
resolved_inputs = self._resolve_remote_objects(inputs)
@ -326,6 +326,13 @@ class ComfyNodeExtension(ExtensionBase):
"Hidden.dynprompt": Hidden.dynprompt,
"Hidden.auth_token_comfy_org": Hidden.auth_token_comfy_org,
"Hidden.api_key_comfy_org": Hidden.api_key_comfy_org,
# Uppercase enum VALUE forms — V3 execution engine passes these
"UNIQUE_ID": Hidden.unique_id,
"PROMPT": Hidden.prompt,
"EXTRA_PNGINFO": Hidden.extra_pnginfo,
"DYNPROMPT": Hidden.dynprompt,
"AUTH_TOKEN_COMFY_ORG": Hidden.auth_token_comfy_org,
"API_KEY_COMFY_ORG": Hidden.api_key_comfy_org,
}
# Find and extract hidden parameters (both enum and string form)
@ -383,29 +390,16 @@ class ComfyNodeExtension(ExtensionBase):
if type(result).__name__ == "NodeOutput":
result = result.args
if self._is_comfy_protocol_return(result):
logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=1",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
)
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
if not isinstance(result, tuple):
result = (result,)
logger.debug(
"%s ISO:child_execute_done ext=%s node=%s protocol_return=0 outputs=%d",
LOG_PREFIX,
getattr(self, "name", "?"),
node_name,
len(result),
)
wrapped = self._wrap_unpicklable_objects(result)
return wrapped
async def flush_transport_state(self) -> int:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") != "1":
if os.environ.get("PYISOLATE_CHILD") != "1":
return 0
logger.debug(
"%s ISO:child_flush_start ext=%s", LOG_PREFIX, getattr(self, "name", "?")
@ -493,6 +487,14 @@ class ComfyNodeExtension(ExtensionBase):
}
return {"__pyisolate_attrdict__": True, "data": converted_dict}
from pyisolate._internal.serialization_registry import SerializerRegistry
registry = SerializerRegistry.get_instance()
if registry.is_data_type(type_name):
serializer = registry.get_serializer(type_name)
if serializer:
return serializer(data)
object_id = str(uuid.uuid4())
self.remote_objects[object_id] = data
return RemoteObjectHandle(object_id, type(data).__name__)

View File

@ -6,6 +6,8 @@ import logging
import os
from typing import Any
from comfy.cli_args import args
logger = logging.getLogger(__name__)
@ -13,8 +15,8 @@ def maybe_wrap_model_for_isolation(model_patcher: Any) -> Any:
from comfy.isolation.model_patcher_proxy_registry import ModelPatcherRegistry
from comfy.isolation.model_patcher_proxy import ModelPatcherProxy
isolation_active = os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1"
is_child = os.environ.get("PYISOLATE_CHILD") == "1"
isolation_active = args.use_process_isolation or is_child
if not isolation_active:
return model_patcher

View File

@ -176,7 +176,7 @@ def build_stub_class(
)
scan_shm_forensics("RUNTIME:execute_start", refresh_model_context=True)
try:
if os.environ.get("PYISOLATE_ISOLATION_ACTIVE") == "1":
if os.environ.get("PYISOLATE_CHILD") != "1":
_relieve_host_vram_pressure("RUNTIME:pre_execute", logger)
scan_shm_forensics("RUNTIME:pre_execute", refresh_model_context=True)
from pyisolate._internal.model_serialization import (

View File

@ -776,3 +776,10 @@ class ChromaRadiance(LatentFormat):
def process_out(self, latent):
return latent
class ZImagePixelSpace(ChromaRadiance):
"""Pixel-space latent format for ZImage DCT variant.
No VAE encoding/decoding the model operates directly on RGB pixels.
"""
pass

View File

@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult
else:
for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]]
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]]
tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
return tensor
@ -223,12 +223,19 @@ class DoubleStreamBlock(nn.Module):
del txt_k, img_k
v = torch.cat((txt_v, img_v), dim=2)
del txt_v, img_v
extra_options["img_slice"] = [txt.shape[1], q.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# run actual attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v
if "attn1_output_patch" in transformer_patches:
extra_options["img_slice"] = [txt.shape[1], attn.shape[1]]
patch = transformer_patches["attn1_output_patch"]
for p in patch:
attn = p(attn, extra_options)
@ -321,6 +328,12 @@ class SingleStreamBlock(nn.Module):
del qkv
q, k = self.norm(q, k, v)
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(q, k, v, pe=pe, attn_mask=attn_mask, extra_options=extra_options)
q, k, v, pe, attn_mask = out.get("q", q), out.get("k", k), out.get("v", v), out.get("pe", pe), out.get("attn_mask", attn_mask)
# compute attention
attn = attention(q, k, v, pe=pe, mask=attn_mask, transformer_options=transformer_options)
del q, k, v

View File

@ -31,6 +31,8 @@ def rope(pos: Tensor, dim: int, theta: int) -> Tensor:
def _apply_rope1(x: Tensor, freqs_cis: Tensor):
x_ = x.to(dtype=freqs_cis.dtype).reshape(*x.shape[:-1], -1, 1, 2)
if x_.shape[2] != 1 and freqs_cis.shape[2] != 1 and x_.shape[2] != freqs_cis.shape[2]:
freqs_cis = freqs_cis[:, :, :x_.shape[2]]
x_out = freqs_cis[..., 0] * x_[..., 0]
x_out.addcmul_(freqs_cis[..., 1], x_[..., 1])

View File

@ -170,7 +170,7 @@ class Flux(nn.Module):
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids})
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
img = out["img"]
txt = out["txt"]
img_ids = out["img_ids"]

View File

@ -2,11 +2,16 @@ from typing import Tuple
import torch
import torch.nn as nn
from comfy.ldm.lightricks.model import (
ADALN_BASE_PARAMS_COUNT,
ADALN_CROSS_ATTN_PARAMS_COUNT,
CrossAttention,
FeedForward,
AdaLayerNormSingle,
PixArtAlphaTextProjection,
NormSingleLinearTextProjection,
LTXVModel,
apply_cross_attention_adaln,
compute_prompt_timestep,
)
from comfy.ldm.lightricks.symmetric_patchifier import AudioPatchifier
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@ -87,6 +92,8 @@ class BasicAVTransformerBlock(nn.Module):
v_context_dim=None,
a_context_dim=None,
attn_precision=None,
apply_gated_attention=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@ -94,6 +101,7 @@ class BasicAVTransformerBlock(nn.Module):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=v_dim,
@ -101,6 +109,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=vd_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -111,6 +120,7 @@ class BasicAVTransformerBlock(nn.Module):
dim_head=ad_head,
context_dim=None,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -122,6 +132,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=v_heads,
dim_head=vd_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -132,6 +143,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -144,6 +156,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -156,6 +169,7 @@ class BasicAVTransformerBlock(nn.Module):
heads=a_heads,
dim_head=ad_head,
attn_precision=self.attn_precision,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -168,11 +182,16 @@ class BasicAVTransformerBlock(nn.Module):
a_dim, dim_out=a_dim, glu=True, dtype=dtype, device=device, operations=operations
)
self.scale_shift_table = nn.Parameter(torch.empty(6, v_dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, v_dim, device=device, dtype=dtype))
self.audio_scale_shift_table = nn.Parameter(
torch.empty(6, a_dim, device=device, dtype=dtype)
torch.empty(num_ada_params, a_dim, device=device, dtype=dtype)
)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, v_dim, device=device, dtype=dtype))
self.audio_prompt_scale_shift_table = nn.Parameter(torch.empty(2, a_dim, device=device, dtype=dtype))
self.scale_shift_table_a2v_ca_audio = nn.Parameter(
torch.empty(5, a_dim, device=device, dtype=dtype)
)
@ -215,10 +234,30 @@ class BasicAVTransformerBlock(nn.Module):
return (*scale_shift_ada_values, *gate_ada_values)
def _apply_text_cross_attention(
self, x, context, attn, scale_shift_table, prompt_scale_shift_table,
timestep, prompt_timestep, attention_mask, transformer_options,
):
"""Apply text cross-attention, with optional ADaLN modulation."""
if self.cross_attention_adaln:
shift_q, scale_q, gate = self.get_ada_values(
scale_shift_table, x.shape[0], timestep, slice(6, 9)
)
return apply_cross_attention_adaln(
x, context, attn, shift_q, scale_q, gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask, transformer_options,
)
return attn(
comfy.ldm.common_dit.rms_norm(x), context=context,
mask=attention_mask, transformer_options=transformer_options,
)
def forward(
self, x: Tuple[torch.Tensor, torch.Tensor], v_context=None, a_context=None, attention_mask=None, v_timestep=None, a_timestep=None,
v_pe=None, a_pe=None, v_cross_pe=None, a_cross_pe=None, v_cross_scale_shift_timestep=None, a_cross_scale_shift_timestep=None,
v_cross_gate_timestep=None, a_cross_gate_timestep=None, transformer_options=None, self_attention_mask=None,
v_prompt_timestep=None, a_prompt_timestep=None,
) -> Tuple[torch.Tensor, torch.Tensor]:
run_vx = transformer_options.get("run_vx", True)
run_ax = transformer_options.get("run_ax", True)
@ -240,7 +279,11 @@ class BasicAVTransformerBlock(nn.Module):
vgate_msa = self.get_ada_values(self.scale_shift_table, vx.shape[0], v_timestep, slice(2, 3))[0]
vx.addcmul_(attn1_out, vgate_msa)
del vgate_msa, attn1_out
vx.add_(self.attn2(comfy.ldm.common_dit.rms_norm(vx), context=v_context, mask=attention_mask, transformer_options=transformer_options))
vx.add_(self._apply_text_cross_attention(
vx, v_context, self.attn2, self.scale_shift_table,
getattr(self, 'prompt_scale_shift_table', None),
v_timestep, v_prompt_timestep, attention_mask, transformer_options,)
)
# audio
if run_ax:
@ -254,7 +297,11 @@ class BasicAVTransformerBlock(nn.Module):
agate_msa = self.get_ada_values(self.audio_scale_shift_table, ax.shape[0], a_timestep, slice(2, 3))[0]
ax.addcmul_(attn1_out, agate_msa)
del agate_msa, attn1_out
ax.add_(self.audio_attn2(comfy.ldm.common_dit.rms_norm(ax), context=a_context, mask=attention_mask, transformer_options=transformer_options))
ax.add_(self._apply_text_cross_attention(
ax, a_context, self.audio_attn2, self.audio_scale_shift_table,
getattr(self, 'audio_prompt_scale_shift_table', None),
a_timestep, a_prompt_timestep, attention_mask, transformer_options,)
)
# video - audio cross attention.
if run_a2v or run_v2a:
@ -351,6 +398,9 @@ class LTXAVModel(LTXVModel):
use_middle_indices_grid=False,
timestep_scale_multiplier=1000.0,
av_ca_timestep_scale_multiplier=1.0,
apply_gated_attention=False,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@ -362,6 +412,7 @@ class LTXAVModel(LTXVModel):
self.audio_attention_head_dim = audio_attention_head_dim
self.audio_num_attention_heads = audio_num_attention_heads
self.audio_positional_embedding_max_pos = audio_positional_embedding_max_pos
self.apply_gated_attention = apply_gated_attention
# Calculate audio dimensions
self.audio_inner_dim = audio_num_attention_heads * audio_attention_head_dim
@ -386,6 +437,8 @@ class LTXAVModel(LTXVModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@ -400,14 +453,28 @@ class LTXAVModel(LTXVModel):
)
# Audio-specific AdaLN
audio_embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.audio_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=audio_embedding_coefficient,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.audio_prompt_adaln_single = AdaLayerNormSingle(
self.audio_inner_dim,
embedding_coefficient=2,
use_additional_conditions=False,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_prompt_adaln_single = None
num_scale_shift_values = 4
self.av_ca_video_scale_shift_adaln_single = AdaLayerNormSingle(
self.inner_dim,
@ -443,35 +510,73 @@ class LTXAVModel(LTXVModel):
)
# Audio caption projection
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.audio_caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.audio_caption_projection = lambda a: a
else:
self.audio_caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.audio_inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
connector_split_rope = kwargs.get("rope_type", "split") == "split"
connector_gated_attention = kwargs.get("connector_apply_gated_attention", False)
attention_head_dim = kwargs.get("connector_attention_head_dim", 128)
num_attention_heads = kwargs.get("connector_num_attention_heads", 30)
num_layers = kwargs.get("connector_num_layers", 2)
self.audio_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=kwargs.get("audio_connector_attention_head_dim", attention_head_dim),
num_attention_heads=kwargs.get("audio_connector_num_attention_heads", num_attention_heads),
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
self.video_embeddings_connector = Embeddings1DConnector(
split_rope=True,
attention_head_dim=attention_head_dim,
num_attention_heads=num_attention_heads,
num_layers=num_layers,
split_rope=connector_split_rope,
double_precision_rope=True,
apply_gated_attention=connector_gated_attention,
dtype=dtype,
device=device,
operations=self.operations,
)
def preprocess_text_embeds(self, context):
if context.shape[-1] == self.caption_channels * 2:
return context
out_vid = self.video_embeddings_connector(context)[0]
out_audio = self.audio_embeddings_connector(context)[0]
def preprocess_text_embeds(self, context, unprocessed=False):
# LTXv2 fully processed context has dimension of self.caption_channels * 2
# LTXv2.3 fully processed context has dimension of self.cross_attention_dim + self.audio_cross_attention_dim
if not unprocessed:
if context.shape[-1] in (self.cross_attention_dim + self.audio_cross_attention_dim, self.caption_channels * 2):
return context
if context.shape[-1] == self.cross_attention_dim + self.audio_cross_attention_dim:
context_vid = context[:, :, :self.cross_attention_dim]
context_audio = context[:, :, self.cross_attention_dim:]
else:
context_vid = context
context_audio = context
if self.caption_proj_before_connector:
context_vid = self.caption_projection(context_vid)
context_audio = self.audio_caption_projection(context_audio)
out_vid = self.video_embeddings_connector(context_vid)[0]
out_audio = self.audio_embeddings_connector(context_audio)[0]
return torch.concat((out_vid, out_audio), dim=-1)
def _init_transformer_blocks(self, device, dtype, **kwargs):
@ -487,6 +592,8 @@ class LTXAVModel(LTXVModel):
ad_head=self.audio_attention_head_dim,
v_context_dim=self.cross_attention_dim,
a_context_dim=self.audio_cross_attention_dim,
apply_gated_attention=self.apply_gated_attention,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@ -608,6 +715,10 @@ class LTXAVModel(LTXVModel):
v_timestep = CompressedTimestep(v_timestep.view(batch_size, -1, v_timestep.shape[-1]), v_patches_per_frame)
v_embedded_timestep = CompressedTimestep(v_embedded_timestep.view(batch_size, -1, v_embedded_timestep.shape[-1]), v_patches_per_frame)
v_prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
# Prepare audio timestep
a_timestep = kwargs.get("a_timestep")
if a_timestep is not None:
@ -618,25 +729,25 @@ class LTXAVModel(LTXVModel):
# Cross-attention timesteps - compress these too
av_ca_audio_scale_shift_timestep, _ = self.av_ca_audio_scale_shift_adaln_single(
a_timestep_flat,
timestep.max().expand_as(a_timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_video_scale_shift_timestep, _ = self.av_ca_video_scale_shift_adaln_single(
timestep_flat,
a_timestep.max().expand_as(timestep_flat),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_a2v_gate_noise_timestep, _ = self.av_ca_a2v_gate_adaln_single(
timestep_flat * av_ca_factor,
a_timestep.max().expand_as(timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
av_ca_v2a_gate_noise_timestep, _ = self.av_ca_v2a_gate_adaln_single(
a_timestep_flat * av_ca_factor,
timestep.max().expand_as(a_timestep_flat) * av_ca_factor,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@ -660,29 +771,40 @@ class LTXAVModel(LTXVModel):
# Audio timesteps
a_timestep = a_timestep.view(batch_size, -1, a_timestep.shape[-1])
a_embedded_timestep = a_embedded_timestep.view(batch_size, -1, a_embedded_timestep.shape[-1])
a_prompt_timestep = compute_prompt_timestep(
self.audio_prompt_adaln_single, a_timestep_scaled, batch_size, hidden_dtype
)
else:
a_timestep = timestep_scaled
a_embedded_timestep = kwargs.get("embedded_timestep")
cross_av_timestep_ss = []
a_prompt_timestep = None
return [v_timestep, a_timestep, cross_av_timestep_ss], [
return [v_timestep, a_timestep, cross_av_timestep_ss, v_prompt_timestep, a_prompt_timestep], [
v_embedded_timestep,
a_embedded_timestep,
]
], None
def _prepare_context(self, context, batch_size, x, attention_mask=None):
vx = x[0]
ax = x[1]
video_dim = vx.shape[-1]
audio_dim = ax.shape[-1]
v_context_dim = self.caption_channels if self.caption_proj_before_connector is False else video_dim
a_context_dim = self.caption_channels if self.caption_proj_before_connector is False else audio_dim
v_context, a_context = torch.split(
context, int(context.shape[-1] / 2), len(context.shape) - 1
context, [v_context_dim, a_context_dim], len(context.shape) - 1
)
v_context, attention_mask = super()._prepare_context(
v_context, batch_size, vx, attention_mask
)
if self.audio_caption_projection is not None:
if self.caption_proj_before_connector is False:
a_context = self.audio_caption_projection(a_context)
a_context = a_context.view(batch_size, -1, ax.shape[-1])
a_context = a_context.view(batch_size, -1, audio_dim)
return [v_context, a_context], attention_mask
@ -744,6 +866,9 @@ class LTXAVModel(LTXVModel):
av_ca_v2a_gate_noise_timestep,
) = timestep[2]
v_prompt_timestep = timestep[3]
a_prompt_timestep = timestep[4]
"""Process transformer blocks for LTXAV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
@ -771,6 +896,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=args["a_cross_gate_timestep"],
transformer_options=args["transformer_options"],
self_attention_mask=args.get("self_attention_mask"),
v_prompt_timestep=args.get("v_prompt_timestep"),
a_prompt_timestep=args.get("a_prompt_timestep"),
)
return out
@ -792,6 +919,8 @@ class LTXAVModel(LTXVModel):
"a_cross_gate_timestep": av_ca_v2a_gate_noise_timestep,
"transformer_options": transformer_options,
"self_attention_mask": self_attention_mask,
"v_prompt_timestep": v_prompt_timestep,
"a_prompt_timestep": a_prompt_timestep,
},
{"original_block": block_wrap},
)
@ -814,6 +943,8 @@ class LTXAVModel(LTXVModel):
a_cross_gate_timestep=av_ca_v2a_gate_noise_timestep,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
v_prompt_timestep=v_prompt_timestep,
a_prompt_timestep=a_prompt_timestep,
)
return [vx, ax]

View File

@ -50,6 +50,7 @@ class BasicTransformerBlock1D(nn.Module):
d_head,
context_dim=None,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@ -63,6 +64,7 @@ class BasicTransformerBlock1D(nn.Module):
heads=n_heads,
dim_head=d_head,
context_dim=None,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,
@ -121,6 +123,7 @@ class Embeddings1DConnector(nn.Module):
positional_embedding_max_pos=[4096],
causal_temporal_positioning=False,
num_learnable_registers: Optional[int] = 128,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@ -145,6 +148,7 @@ class Embeddings1DConnector(nn.Module):
num_attention_heads,
attention_head_dim,
context_dim=cross_attention_dim,
apply_gated_attention=apply_gated_attention,
dtype=dtype,
device=device,
operations=operations,

View File

@ -275,6 +275,30 @@ class PixArtAlphaTextProjection(nn.Module):
return hidden_states
class NormSingleLinearTextProjection(nn.Module):
"""Text projection for 20B models - single linear with RMSNorm (no activation)."""
def __init__(
self, in_features, hidden_size, dtype=None, device=None, operations=None
):
super().__init__()
if operations is None:
operations = comfy.ops.disable_weight_init
self.in_norm = operations.RMSNorm(
in_features, eps=1e-6, elementwise_affine=False
)
self.linear_1 = operations.Linear(
in_features, hidden_size, bias=True, dtype=dtype, device=device
)
self.hidden_size = hidden_size
self.in_features = in_features
def forward(self, caption):
caption = self.in_norm(caption)
caption = caption * (self.hidden_size / self.in_features) ** 0.5
return self.linear_1(caption)
class GELU_approx(nn.Module):
def __init__(self, dim_in, dim_out, dtype=None, device=None, operations=None):
super().__init__()
@ -343,6 +367,7 @@ class CrossAttention(nn.Module):
dim_head=64,
dropout=0.0,
attn_precision=None,
apply_gated_attention=False,
dtype=None,
device=None,
operations=None,
@ -362,6 +387,12 @@ class CrossAttention(nn.Module):
self.to_k = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
self.to_v = operations.Linear(context_dim, inner_dim, bias=True, dtype=dtype, device=device)
# Optional per-head gating
if apply_gated_attention:
self.to_gate_logits = operations.Linear(query_dim, heads, bias=True, dtype=dtype, device=device)
else:
self.to_gate_logits = None
self.to_out = nn.Sequential(
operations.Linear(inner_dim, query_dim, dtype=dtype, device=device), nn.Dropout(dropout)
)
@ -383,16 +414,30 @@ class CrossAttention(nn.Module):
out = comfy.ldm.modules.attention.optimized_attention(q, k, v, self.heads, attn_precision=self.attn_precision, transformer_options=transformer_options)
else:
out = comfy.ldm.modules.attention.optimized_attention_masked(q, k, v, self.heads, mask, attn_precision=self.attn_precision, transformer_options=transformer_options)
# Apply per-head gating if enabled
if self.to_gate_logits is not None:
gate_logits = self.to_gate_logits(x) # (B, T, H)
b, t, _ = out.shape
out = out.view(b, t, self.heads, self.dim_head)
gates = 2.0 * torch.sigmoid(gate_logits) # zero-init -> identity
out = out * gates.unsqueeze(-1)
out = out.view(b, t, self.heads * self.dim_head)
return self.to_out(out)
# 6 base ADaLN params (shift/scale/gate for MSA + MLP), +3 for cross-attention Q (shift/scale/gate)
ADALN_BASE_PARAMS_COUNT = 6
ADALN_CROSS_ATTN_PARAMS_COUNT = 9
class BasicTransformerBlock(nn.Module):
def __init__(
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, dtype=None, device=None, operations=None
self, dim, n_heads, d_head, context_dim=None, attn_precision=None, cross_attention_adaln=False, dtype=None, device=None, operations=None
):
super().__init__()
self.attn_precision = attn_precision
self.cross_attention_adaln = cross_attention_adaln
self.attn1 = CrossAttention(
query_dim=dim,
heads=n_heads,
@ -416,18 +461,25 @@ class BasicTransformerBlock(nn.Module):
operations=operations,
)
self.scale_shift_table = nn.Parameter(torch.empty(6, dim, device=device, dtype=dtype))
num_ada_params = ADALN_CROSS_ATTN_PARAMS_COUNT if cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.scale_shift_table = nn.Parameter(torch.empty(num_ada_params, dim, device=device, dtype=dtype))
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)).unbind(dim=2)
if cross_attention_adaln:
self.prompt_scale_shift_table = nn.Parameter(torch.empty(2, dim, device=device, dtype=dtype))
attn1_input = comfy.ldm.common_dit.rms_norm(x)
attn1_input = torch.addcmul(attn1_input, attn1_input, scale_msa).add_(shift_msa)
attn1_input = self.attn1(attn1_input, pe=pe, mask=self_attention_mask, transformer_options=transformer_options)
x.addcmul_(attn1_input, gate_msa)
del attn1_input
def forward(self, x, context=None, attention_mask=None, timestep=None, pe=None, transformer_options={}, self_attention_mask=None, prompt_timestep=None):
shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (self.scale_shift_table[None, None, :6].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, :6, :]).unbind(dim=2)
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
x += self.attn1(comfy.ldm.common_dit.rms_norm(x) * (1 + scale_msa) + shift_msa, pe=pe, mask=self_attention_mask, transformer_options=transformer_options) * gate_msa
if self.cross_attention_adaln:
shift_q_mca, scale_q_mca, gate_mca = (self.scale_shift_table[None, None, 6:9].to(device=x.device, dtype=x.dtype) + timestep.reshape(x.shape[0], timestep.shape[1], self.scale_shift_table.shape[0], -1)[:, :, 6:9, :]).unbind(dim=2)
x += apply_cross_attention_adaln(
x, context, self.attn2, shift_q_mca, scale_q_mca, gate_mca,
self.prompt_scale_shift_table, prompt_timestep, attention_mask, transformer_options,
)
else:
x += self.attn2(x, context=context, mask=attention_mask, transformer_options=transformer_options)
y = comfy.ldm.common_dit.rms_norm(x)
y = torch.addcmul(y, y, scale_mlp).add_(shift_mlp)
@ -435,6 +487,47 @@ class BasicTransformerBlock(nn.Module):
return x
def compute_prompt_timestep(adaln_module, timestep_scaled, batch_size, hidden_dtype):
"""Compute a single global prompt timestep for cross-attention ADaLN.
Uses the max across tokens (matching JAX max_per_segment) and broadcasts
over text tokens. Returns None when *adaln_module* is None.
"""
if adaln_module is None:
return None
ts_input = (
timestep_scaled.max(dim=1, keepdim=True).values.flatten()
if timestep_scaled.dim() > 1
else timestep_scaled.flatten()
)
prompt_ts, _ = adaln_module(
ts_input,
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
)
return prompt_ts.view(batch_size, 1, prompt_ts.shape[-1])
def apply_cross_attention_adaln(
x, context, attn, q_shift, q_scale, q_gate,
prompt_scale_shift_table, prompt_timestep,
attention_mask=None, transformer_options={},
):
"""Apply cross-attention with ADaLN modulation (shift/scale/gate on Q and KV).
Q params (q_shift, q_scale, q_gate) are pre-extracted by the caller so
that both regular tensors and CompressedTimestep are supported.
"""
batch_size = x.shape[0]
shift_kv, scale_kv = (
prompt_scale_shift_table[None, None].to(device=x.device, dtype=x.dtype)
+ prompt_timestep.reshape(batch_size, prompt_timestep.shape[1], 2, -1)
).unbind(dim=2)
attn_input = comfy.ldm.common_dit.rms_norm(x) * (1 + q_scale) + q_shift
encoder_hidden_states = context * (1 + scale_kv) + shift_kv
return attn(attn_input, context=encoder_hidden_states, mask=attention_mask, transformer_options=transformer_options) * q_gate
def get_fractional_positions(indices_grid, max_pos):
n_pos_dims = indices_grid.shape[1]
assert n_pos_dims == len(max_pos), f'Number of position dimensions ({n_pos_dims}) must match max_pos length ({len(max_pos)})'
@ -556,6 +649,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
vae_scale_factors: tuple = (8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
caption_projection_first_linear=True,
dtype=None,
device=None,
operations=None,
@ -582,6 +678,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.causal_temporal_positioning = causal_temporal_positioning
self.operations = operations
self.timestep_scale_multiplier = timestep_scale_multiplier
self.caption_proj_before_connector = caption_proj_before_connector
self.cross_attention_adaln = cross_attention_adaln
self.caption_projection_first_linear = caption_projection_first_linear
# Common dimensions
self.inner_dim = num_attention_heads * attention_head_dim
@ -609,17 +708,37 @@ class LTXBaseModel(torch.nn.Module, ABC):
self.in_channels, self.inner_dim, bias=True, dtype=dtype, device=device
)
embedding_coefficient = ADALN_CROSS_ATTN_PARAMS_COUNT if self.cross_attention_adaln else ADALN_BASE_PARAMS_COUNT
self.adaln_single = AdaLayerNormSingle(
self.inner_dim, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
self.inner_dim, embedding_coefficient=embedding_coefficient, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
if self.cross_attention_adaln:
self.prompt_adaln_single = AdaLayerNormSingle(
self.inner_dim, embedding_coefficient=2, use_additional_conditions=False, dtype=dtype, device=device, operations=self.operations
)
else:
self.prompt_adaln_single = None
if self.caption_proj_before_connector:
if self.caption_projection_first_linear:
self.caption_projection = NormSingleLinearTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
else:
self.caption_projection = lambda a: a
else:
self.caption_projection = PixArtAlphaTextProjection(
in_features=self.caption_channels,
hidden_size=self.inner_dim,
dtype=dtype,
device=device,
operations=self.operations,
)
@abstractmethod
def _init_model_components(self, device, dtype, **kwargs):
@ -665,9 +784,9 @@ class LTXBaseModel(torch.nn.Module, ABC):
if grid_mask is not None:
timestep = timestep[:, grid_mask]
timestep = timestep * self.timestep_scale_multiplier
timestep_scaled = timestep * self.timestep_scale_multiplier
timestep, embedded_timestep = self.adaln_single(
timestep.flatten(),
timestep_scaled.flatten(),
{"resolution": None, "aspect_ratio": None},
batch_size=batch_size,
hidden_dtype=hidden_dtype,
@ -677,14 +796,18 @@ class LTXBaseModel(torch.nn.Module, ABC):
timestep = timestep.view(batch_size, -1, timestep.shape[-1])
embedded_timestep = embedded_timestep.view(batch_size, -1, embedded_timestep.shape[-1])
return timestep, embedded_timestep
prompt_timestep = compute_prompt_timestep(
self.prompt_adaln_single, timestep_scaled, batch_size, hidden_dtype
)
return timestep, embedded_timestep, prompt_timestep
def _prepare_context(self, context, batch_size, x, attention_mask=None):
"""Prepare context for transformer blocks."""
if self.caption_projection is not None:
if self.caption_proj_before_connector is False:
context = self.caption_projection(context)
context = context.view(batch_size, -1, x.shape[-1])
context = context.view(batch_size, -1, x.shape[-1])
return context, attention_mask
def _precompute_freqs_cis(
@ -792,7 +915,8 @@ class LTXBaseModel(torch.nn.Module, ABC):
merged_args.update(additional_args)
# Prepare timestep and context
timestep, embedded_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
timestep, embedded_timestep, prompt_timestep = self._prepare_timestep(timestep, batch_size, input_dtype, **merged_args)
merged_args["prompt_timestep"] = prompt_timestep
context, attention_mask = self._prepare_context(context, batch_size, x, attention_mask)
# Prepare attention mask and positional embeddings
@ -833,7 +957,9 @@ class LTXVModel(LTXBaseModel):
causal_temporal_positioning=False,
vae_scale_factors=(8, 32, 32),
use_middle_indices_grid=False,
timestep_scale_multiplier = 1000.0,
timestep_scale_multiplier=1000.0,
caption_proj_before_connector=False,
cross_attention_adaln=False,
dtype=None,
device=None,
operations=None,
@ -852,6 +978,8 @@ class LTXVModel(LTXBaseModel):
vae_scale_factors=vae_scale_factors,
use_middle_indices_grid=use_middle_indices_grid,
timestep_scale_multiplier=timestep_scale_multiplier,
caption_proj_before_connector=caption_proj_before_connector,
cross_attention_adaln=cross_attention_adaln,
dtype=dtype,
device=device,
operations=operations,
@ -860,7 +988,6 @@ class LTXVModel(LTXBaseModel):
def _init_model_components(self, device, dtype, **kwargs):
"""Initialize LTXV-specific components."""
# No additional components needed for LTXV beyond base class
pass
def _init_transformer_blocks(self, device, dtype, **kwargs):
@ -872,6 +999,7 @@ class LTXVModel(LTXBaseModel):
self.num_attention_heads,
self.attention_head_dim,
context_dim=self.cross_attention_dim,
cross_attention_adaln=self.cross_attention_adaln,
dtype=dtype,
device=device,
operations=self.operations,
@ -1149,16 +1277,17 @@ class LTXVModel(LTXBaseModel):
"""Process transformer blocks for LTXV."""
patches_replace = transformer_options.get("patches_replace", {})
blocks_replace = patches_replace.get("dit", {})
prompt_timestep = kwargs.get("prompt_timestep", None)
for i, block in enumerate(self.transformer_blocks):
if ("double_block", i) in blocks_replace:
def block_wrap(args):
out = {}
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"))
out["img"] = block(args["img"], context=args["txt"], attention_mask=args["attention_mask"], timestep=args["vec"], pe=args["pe"], transformer_options=args["transformer_options"], self_attention_mask=args.get("self_attention_mask"), prompt_timestep=args.get("prompt_timestep"))
return out
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask}, {"original_block": block_wrap})
out = blocks_replace[("double_block", i)]({"img": x, "txt": context, "attention_mask": attention_mask, "vec": timestep, "pe": pe, "transformer_options": transformer_options, "self_attention_mask": self_attention_mask, "prompt_timestep": prompt_timestep}, {"original_block": block_wrap})
x = out["img"]
else:
x = block(
@ -1169,6 +1298,7 @@ class LTXVModel(LTXBaseModel):
pe=pe,
transformer_options=transformer_options,
self_attention_mask=self_attention_mask,
prompt_timestep=prompt_timestep,
)
return x

View File

@ -13,7 +13,7 @@ from comfy.ldm.lightricks.vae.causal_audio_autoencoder import (
CausalityAxis,
CausalAudioAutoencoder,
)
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder
from comfy.ldm.lightricks.vocoders.vocoder import Vocoder, VocoderWithBWE
LATENT_DOWNSAMPLE_FACTOR = 4
@ -141,7 +141,10 @@ class AudioVAE(torch.nn.Module):
vocoder_sd = utils.state_dict_prefix_replace(state_dict, {"vocoder.": ""}, filter_keys=True)
self.autoencoder = CausalAudioAutoencoder(config=component_config.autoencoder)
self.vocoder = Vocoder(config=component_config.vocoder)
if "bwe" in component_config.vocoder:
self.vocoder = VocoderWithBWE(config=component_config.vocoder)
else:
self.vocoder = Vocoder(config=component_config.vocoder)
self.autoencoder.load_state_dict(vae_sd, strict=False)
self.vocoder.load_state_dict(vocoder_sd, strict=False)

View File

@ -822,26 +822,23 @@ class CausalAudioAutoencoder(nn.Module):
super().__init__()
if config is None:
config = self._guess_config()
config = self.get_default_config()
# Extract encoder and decoder configs from the new format
model_config = config.get("model", {}).get("params", {})
variables_config = config.get("variables", {})
self.sampling_rate = variables_config.get(
"sampling_rate",
model_config.get("sampling_rate", config.get("sampling_rate", 16000)),
self.sampling_rate = model_config.get(
"sampling_rate", config.get("sampling_rate", 16000)
)
encoder_config = model_config.get("encoder", model_config.get("ddconfig", {}))
decoder_config = model_config.get("decoder", encoder_config)
# Load mel spectrogram parameters
self.mel_bins = encoder_config.get("mel_bins", 64)
self.mel_hop_length = model_config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = model_config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
self.mel_hop_length = config.get("preprocessing", {}).get("stft", {}).get("hop_length", 160)
self.n_fft = config.get("preprocessing", {}).get("stft", {}).get("filter_length", 1024)
# Store causality configuration at VAE level (not just in encoder internals)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.WIDTH.value)
causality_axis_value = encoder_config.get("causality_axis", CausalityAxis.HEIGHT.value)
self.causality_axis = CausalityAxis.str_to_enum(causality_axis_value)
self.is_causal = self.causality_axis == CausalityAxis.HEIGHT
@ -850,44 +847,38 @@ class CausalAudioAutoencoder(nn.Module):
self.per_channel_statistics = processor()
def _guess_config(self):
encoder_config = {
# Required parameters - based on ltx-video-av-1679000 model metadata
"ch": 128,
"out_ch": 8,
"ch_mult": [1, 2, 4], # Based on metadata: [1, 2, 4] not [1, 2, 4, 8]
"num_res_blocks": 2,
"attn_resolutions": [], # Based on metadata: empty list, no attention
"dropout": 0.0,
"resamp_with_conv": True,
"in_channels": 2, # stereo
"resolution": 256,
"z_channels": 8,
def get_default_config(self):
ddconfig = {
"double_z": True,
"attn_type": "vanilla",
"mid_block_add_attention": False, # Based on metadata: false
"mel_bins": 64,
"z_channels": 8,
"resolution": 256,
"downsample_time": False,
"in_channels": 2,
"out_ch": 2,
"ch": 128,
"ch_mult": [1, 2, 4],
"num_res_blocks": 2,
"attn_resolutions": [],
"dropout": 0.0,
"mid_block_add_attention": False,
"norm_type": "pixel",
"causality_axis": "height", # Based on metadata
"mel_bins": 64, # Based on metadata: mel_bins = 64
}
decoder_config = {
# Inherits encoder config, can override specific params
**encoder_config,
"out_ch": 2, # Stereo audio output (2 channels)
"give_pre_end": False,
"tanh_out": False,
"causality_axis": "height",
}
config = {
"_class_name": "CausalAudioAutoencoder",
"sampling_rate": 16000,
"model": {
"params": {
"encoder": encoder_config,
"decoder": decoder_config,
"ddconfig": ddconfig,
"sampling_rate": 16000,
}
},
"preprocessing": {
"stft": {
"filter_length": 1024,
"hop_length": 160,
},
},
}
return config

View File

@ -15,6 +15,9 @@ from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init
def in_meta_context():
return torch.device("meta") == torch.empty(0).device
def mark_conv3d_ended(module):
tid = threading.get_ident()
for _, m in module.named_modules():
@ -350,6 +353,10 @@ class Decoder(nn.Module):
output_channel = output_channel * block_params.get("multiplier", 2)
if block_name == "compress_all":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_space":
output_channel = output_channel * block_params.get("multiplier", 1)
if block_name == "compress_time":
output_channel = output_channel * block_params.get("multiplier", 1)
self.conv_in = make_conv_nd(
dims,
@ -395,17 +402,21 @@ class Decoder(nn.Module):
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_time":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(2, 1, 1),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_space":
output_channel = output_channel // block_params.get("multiplier", 1)
block = DepthToSpaceUpsample(
dims=dims,
in_channels=input_channel,
stride=(1, 2, 2),
out_channels_reduction_factor=block_params.get("multiplier", 1),
spatial_padding_mode=spatial_padding_mode,
)
elif block_name == "compress_all":
@ -455,6 +466,15 @@ class Decoder(nn.Module):
output_channel * 2, 0, operations=ops,
)
self.last_scale_shift_table = nn.Parameter(torch.empty(2, output_channel))
else:
self.register_buffer(
"last_scale_shift_table",
torch.tensor(
[0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(2, output_channel),
persistent=False,
)
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
@ -883,6 +903,15 @@ class ResnetBlock3D(nn.Module):
self.scale_shift_table = nn.Parameter(
torch.randn(4, in_channels) / in_channels**0.5
)
else:
self.register_buffer(
"scale_shift_table",
torch.tensor(
[0.0, 0.0, 0.0, 0.0],
device="cpu" if in_meta_context() else None
).unsqueeze(1).expand(4, in_channels),
persistent=False,
)
self.temporal_cache_state={}
@ -1012,9 +1041,6 @@ class processor(nn.Module):
super().__init__()
self.register_buffer("std-of-means", torch.empty(128))
self.register_buffer("mean-of-means", torch.empty(128))
self.register_buffer("mean-of-stds", torch.empty(128))
self.register_buffer("mean-of-stds_over_std-of-means", torch.empty(128))
self.register_buffer("channel", torch.empty(128))
def un_normalize(self, x):
return (x * self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)) + self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)
@ -1027,9 +1053,12 @@ class VideoVAE(nn.Module):
super().__init__()
if config is None:
config = self.guess_config(version)
config = self.get_default_config(version)
self.config = config
self.timestep_conditioning = config.get("timestep_conditioning", False)
self.decode_noise_scale = config.get("decode_noise_scale", 0.025)
self.decode_timestep = config.get("decode_timestep", 0.05)
double_z = config.get("double_z", True)
latent_log_var = config.get(
"latent_log_var", "per_channel" if double_z else "none"
@ -1044,6 +1073,7 @@ class VideoVAE(nn.Module):
latent_log_var=latent_log_var,
norm_layer=config.get("norm_layer", "group_norm"),
spatial_padding_mode=config.get("spatial_padding_mode", "zeros"),
base_channels=config.get("encoder_base_channels", 128),
)
self.decoder = Decoder(
@ -1051,6 +1081,7 @@ class VideoVAE(nn.Module):
in_channels=config["latent_channels"],
out_channels=config.get("out_channels", 3),
blocks=config.get("decoder_blocks", config.get("decoder_blocks", config.get("blocks"))),
base_channels=config.get("decoder_base_channels", 128),
patch_size=config.get("patch_size", 1),
norm_layer=config.get("norm_layer", "group_norm"),
causal=config.get("causal_decoder", False),
@ -1060,7 +1091,7 @@ class VideoVAE(nn.Module):
self.per_channel_statistics = processor()
def guess_config(self, version):
def get_default_config(self, version):
if version == 0:
config = {
"_class_name": "CausalVideoAutoencoder",
@ -1167,8 +1198,7 @@ class VideoVAE(nn.Module):
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means)
def decode(self, x, timestep=0.05, noise_scale=0.025):
def decode(self, x):
if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * noise_scale + (1.0 - noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=timestep)
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)

View File

@ -2,7 +2,9 @@ import torch
import torch.nn.functional as F
import torch.nn as nn
import comfy.ops
import comfy.model_management
import numpy as np
import math
ops = comfy.ops.disable_weight_init
@ -12,6 +14,307 @@ def get_padding(kernel_size, dilation=1):
return int((kernel_size * dilation - dilation) / 2)
# ---------------------------------------------------------------------------
# Anti-aliased resampling helpers (kaiser-sinc filters) for BigVGAN v2
# Adopted from https://github.com/NVIDIA/BigVGAN
# ---------------------------------------------------------------------------
def _sinc(x: torch.Tensor):
return torch.where(
x == 0,
torch.tensor(1.0, device=x.device, dtype=x.dtype),
torch.sin(math.pi * x) / math.pi / x,
)
def kaiser_sinc_filter1d(cutoff, half_width, kernel_size):
even = kernel_size % 2 == 0
half_size = kernel_size // 2
delta_f = 4 * half_width
A = 2.285 * (half_size - 1) * math.pi * delta_f + 7.95
if A > 50.0:
beta = 0.1102 * (A - 8.7)
elif A >= 21.0:
beta = 0.5842 * (A - 21) ** 0.4 + 0.07886 * (A - 21.0)
else:
beta = 0.0
window = torch.kaiser_window(kernel_size, beta=beta, periodic=False)
if even:
time = torch.arange(-half_size, half_size) + 0.5
else:
time = torch.arange(kernel_size) - half_size
if cutoff == 0:
filter_ = torch.zeros_like(time)
else:
filter_ = 2 * cutoff * window * _sinc(2 * cutoff * time)
filter_ /= filter_.sum()
filter = filter_.view(1, 1, kernel_size)
return filter
class LowPassFilter1d(nn.Module):
def __init__(
self,
cutoff=0.5,
half_width=0.6,
stride=1,
padding=True,
padding_mode="replicate",
kernel_size=12,
):
super().__init__()
if cutoff < -0.0:
raise ValueError("Minimum cutoff must be larger than zero.")
if cutoff > 0.5:
raise ValueError("A cutoff above 0.5 does not make sense.")
self.kernel_size = kernel_size
self.even = kernel_size % 2 == 0
self.pad_left = kernel_size // 2 - int(self.even)
self.pad_right = kernel_size // 2
self.stride = stride
self.padding = padding
self.padding_mode = padding_mode
filter = kaiser_sinc_filter1d(cutoff, half_width, kernel_size)
self.register_buffer("filter", filter)
def forward(self, x):
_, C, _ = x.shape
if self.padding:
x = F.pad(x, (self.pad_left, self.pad_right), mode=self.padding_mode)
return F.conv1d(x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C)
class UpSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None, persistent=True, window_type="kaiser"):
super().__init__()
self.ratio = ratio
self.stride = ratio
if window_type == "hann":
# Hann-windowed sinc filter — identical to torchaudio.functional.resample
# with its default parameters (rolloff=0.99, lowpass_filter_width=6).
# Uses replicate boundary padding, matching the reference resampler exactly.
rolloff = 0.99
lowpass_filter_width = 6
width = math.ceil(lowpass_filter_width / rolloff)
self.kernel_size = 2 * width * ratio + 1
self.pad = width
self.pad_left = 2 * width * ratio
self.pad_right = self.kernel_size - ratio
t = (torch.arange(self.kernel_size) / ratio - width) * rolloff
t_clamped = t.clamp(-lowpass_filter_width, lowpass_filter_width)
window = torch.cos(t_clamped * math.pi / lowpass_filter_width / 2) ** 2
filter = (torch.sinc(t) * window * rolloff / ratio).view(1, 1, -1)
else:
# Kaiser-windowed sinc filter (BigVGAN default).
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.pad = self.kernel_size // ratio - 1
self.pad_left = self.pad * self.stride + (self.kernel_size - self.stride) // 2
self.pad_right = (
self.pad * self.stride + (self.kernel_size - self.stride + 1) // 2
)
filter = kaiser_sinc_filter1d(
cutoff=0.5 / ratio, half_width=0.6 / ratio, kernel_size=self.kernel_size
)
self.register_buffer("filter", filter, persistent=persistent)
def forward(self, x):
_, C, _ = x.shape
x = F.pad(x, (self.pad, self.pad), mode="replicate")
x = self.ratio * F.conv_transpose1d(
x, comfy.model_management.cast_to(self.filter.expand(C, -1, -1), dtype=x.dtype, device=x.device), stride=self.stride, groups=C
)
x = x[..., self.pad_left : -self.pad_right]
return x
class DownSample1d(nn.Module):
def __init__(self, ratio=2, kernel_size=None):
super().__init__()
self.ratio = ratio
self.kernel_size = (
int(6 * ratio // 2) * 2 if kernel_size is None else kernel_size
)
self.lowpass = LowPassFilter1d(
cutoff=0.5 / ratio,
half_width=0.6 / ratio,
stride=ratio,
kernel_size=self.kernel_size,
)
def forward(self, x):
return self.lowpass(x)
class Activation1d(nn.Module):
def __init__(
self,
activation,
up_ratio=2,
down_ratio=2,
up_kernel_size=12,
down_kernel_size=12,
):
super().__init__()
self.act = activation
self.upsample = UpSample1d(up_ratio, up_kernel_size)
self.downsample = DownSample1d(down_ratio, down_kernel_size)
def forward(self, x):
x = self.upsample(x)
x = self.act(x)
x = self.downsample(x)
return x
# ---------------------------------------------------------------------------
# BigVGAN v2 activations (Snake / SnakeBeta)
# ---------------------------------------------------------------------------
class Snake(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale:
a = torch.exp(a)
return x + (1.0 / (a + self.eps)) * torch.sin(x * a).pow(2)
class SnakeBeta(nn.Module):
def __init__(
self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True
):
super().__init__()
self.alpha_logscale = alpha_logscale
self.alpha = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.alpha.requires_grad = alpha_trainable
self.beta = nn.Parameter(
torch.zeros(in_features)
if alpha_logscale
else torch.ones(in_features) * alpha
)
self.beta.requires_grad = alpha_trainable
self.eps = 1e-9
def forward(self, x):
a = comfy.model_management.cast_to(self.alpha.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
b = comfy.model_management.cast_to(self.beta.unsqueeze(0).unsqueeze(-1), dtype=x.dtype, device=x.device)
if self.alpha_logscale:
a = torch.exp(a)
b = torch.exp(b)
return x + (1.0 / (b + self.eps)) * torch.sin(x * a).pow(2)
# ---------------------------------------------------------------------------
# BigVGAN v2 AMPBlock (Anti-aliased Multi-Periodicity)
# ---------------------------------------------------------------------------
class AMPBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5), activation="snake"):
super().__init__()
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.convs1 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[0],
padding=get_padding(kernel_size, dilation[0]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[1],
padding=get_padding(kernel_size, dilation[1]),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=dilation[2],
padding=get_padding(kernel_size, dilation[2]),
),
]
)
self.convs2 = nn.ModuleList(
[
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
ops.Conv1d(
channels,
channels,
kernel_size,
1,
dilation=1,
padding=get_padding(kernel_size, 1),
),
]
)
self.acts1 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs1))]
)
self.acts2 = nn.ModuleList(
[Activation1d(act_cls(channels)) for _ in range(len(self.convs2))]
)
def forward(self, x):
for c1, c2, a1, a2 in zip(self.convs1, self.convs2, self.acts1, self.acts2):
xt = a1(x)
xt = c1(xt)
xt = a2(xt)
xt = c2(xt)
x = x + xt
return x
# ---------------------------------------------------------------------------
# HiFi-GAN residual blocks
# ---------------------------------------------------------------------------
class ResBlock1(torch.nn.Module):
def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
super(ResBlock1, self).__init__()
@ -119,6 +422,7 @@ class Vocoder(torch.nn.Module):
"""
Vocoder model for synthesizing audio from spectrograms, based on: https://github.com/jik876/hifi-gan.
Supports both HiFi-GAN (resblock "1"/"2") and BigVGAN v2 (resblock "AMP1").
"""
def __init__(self, config=None):
@ -128,19 +432,39 @@ class Vocoder(torch.nn.Module):
config = self.get_default_config()
resblock_kernel_sizes = config.get("resblock_kernel_sizes", [3, 7, 11])
upsample_rates = config.get("upsample_rates", [6, 5, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 15, 8, 4, 4])
upsample_rates = config.get("upsample_rates", [5, 4, 2, 2, 2])
upsample_kernel_sizes = config.get("upsample_kernel_sizes", [16, 16, 8, 4, 4])
resblock_dilation_sizes = config.get("resblock_dilation_sizes", [[1, 3, 5], [1, 3, 5], [1, 3, 5]])
upsample_initial_channel = config.get("upsample_initial_channel", 1024)
stereo = config.get("stereo", True)
resblock = config.get("resblock", "1")
activation = config.get("activation", "snake")
use_bias_at_final = config.get("use_bias_at_final", True)
# "output_sample_rate" is not present in recent checkpoint configs.
# When absent (None), AudioVAE.output_sample_rate computes it as:
# sample_rate * vocoder.upsample_factor / mel_hop_length
# where upsample_factor = product of all upsample stride lengths,
# and mel_hop_length is loaded from the autoencoder config at
# preprocessing.stft.hop_length (see CausalAudioAutoencoder).
self.output_sample_rate = config.get("output_sample_rate")
self.resblock = config.get("resblock", "1")
self.use_tanh_at_final = config.get("use_tanh_at_final", True)
self.apply_final_activation = config.get("apply_final_activation", True)
self.num_kernels = len(resblock_kernel_sizes)
self.num_upsamples = len(upsample_rates)
in_channels = 128 if stereo else 64
self.conv_pre = ops.Conv1d(in_channels, upsample_initial_channel, 7, 1, padding=3)
resblock_class = ResBlock1 if resblock == "1" else ResBlock2
if self.resblock == "1":
resblock_cls = ResBlock1
elif self.resblock == "2":
resblock_cls = ResBlock2
elif self.resblock == "AMP1":
resblock_cls = AMPBlock1
else:
raise ValueError(f"Unknown resblock type: {self.resblock}")
self.ups = nn.ModuleList()
for i, (u, k) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
@ -157,25 +481,40 @@ class Vocoder(torch.nn.Module):
self.resblocks = nn.ModuleList()
for i in range(len(self.ups)):
ch = upsample_initial_channel // (2 ** (i + 1))
for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
self.resblocks.append(resblock_class(ch, k, d))
for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
if self.resblock == "AMP1":
self.resblocks.append(resblock_cls(ch, k, d, activation=activation))
else:
self.resblocks.append(resblock_cls(ch, k, d))
out_channels = 2 if stereo else 1
self.conv_post = ops.Conv1d(ch, out_channels, 7, 1, padding=3)
if self.resblock == "AMP1":
act_cls = SnakeBeta if activation == "snakebeta" else Snake
self.act_post = Activation1d(act_cls(ch))
else:
self.act_post = nn.LeakyReLU()
self.conv_post = ops.Conv1d(
ch, out_channels, 7, 1, padding=3, bias=use_bias_at_final
)
self.upsample_factor = np.prod([self.ups[i].stride[0] for i in range(len(self.ups))])
def get_default_config(self):
"""Generate default configuration for the vocoder."""
config = {
"resblock_kernel_sizes": [3, 7, 11],
"upsample_rates": [6, 5, 2, 2, 2],
"upsample_kernel_sizes": [16, 15, 8, 4, 4],
"upsample_rates": [5, 4, 2, 2, 2],
"upsample_kernel_sizes": [16, 16, 8, 4, 4],
"resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
"upsample_initial_channel": 1024,
"stereo": True,
"resblock": "1",
"activation": "snake",
"use_bias_at_final": True,
"use_tanh_at_final": True,
}
return config
@ -196,8 +535,10 @@ class Vocoder(torch.nn.Module):
assert x.shape[1] == 2, "Input must have 2 channels for stereo"
x = torch.cat((x[:, 0, :, :], x[:, 1, :, :]), dim=1)
x = self.conv_pre(x)
for i in range(self.num_upsamples):
x = F.leaky_relu(x, LRELU_SLOPE)
if self.resblock != "AMP1":
x = F.leaky_relu(x, LRELU_SLOPE)
x = self.ups[i](x)
xs = None
for j in range(self.num_kernels):
@ -206,8 +547,167 @@ class Vocoder(torch.nn.Module):
else:
xs += self.resblocks[i * self.num_kernels + j](x)
x = xs / self.num_kernels
x = F.leaky_relu(x)
x = self.act_post(x)
x = self.conv_post(x)
x = torch.tanh(x)
if self.apply_final_activation:
if self.use_tanh_at_final:
x = torch.tanh(x)
else:
x = torch.clamp(x, -1, 1)
return x
class _STFTFn(nn.Module):
"""Implements STFT as a convolution with precomputed DFT × Hann-window bases.
The DFT basis rows (real and imaginary parts interleaved) multiplied by the causal
Hann window are stored as buffers and loaded from the checkpoint. Using the exact
bfloat16 bases from training ensures the mel values fed to the BWE generator are
bit-identical to what it was trained on.
"""
def __init__(self, filter_length: int, hop_length: int, win_length: int):
super().__init__()
self.hop_length = hop_length
self.win_length = win_length
n_freqs = filter_length // 2 + 1
self.register_buffer("forward_basis", torch.zeros(n_freqs * 2, 1, filter_length))
self.register_buffer("inverse_basis", torch.zeros(n_freqs * 2, 1, filter_length))
def forward(self, y: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
"""Compute magnitude and phase spectrogram from a batch of waveforms.
Applies causal (left-only) padding of win_length - hop_length samples so that
each output frame depends only on past and present input no lookahead.
The STFT is computed by convolving the padded signal with forward_basis.
Args:
y: Waveform tensor of shape (B, T).
Returns:
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
Computed in float32 for numerical stability, then cast back to
the input dtype.
"""
if y.dim() == 2:
y = y.unsqueeze(1) # (B, 1, T)
left_pad = max(0, self.win_length - self.hop_length) # causal: left-only
y = F.pad(y, (left_pad, 0))
spec = F.conv1d(y, comfy.model_management.cast_to(self.forward_basis, dtype=y.dtype, device=y.device), stride=self.hop_length, padding=0)
n_freqs = spec.shape[1] // 2
real, imag = spec[:, :n_freqs], spec[:, n_freqs:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)
phase = torch.atan2(imag.float(), real.float()).to(real.dtype)
return magnitude, phase
class MelSTFT(nn.Module):
"""Causal log-mel spectrogram module whose buffers are loaded from the checkpoint.
Computes a log-mel spectrogram by running the causal STFT (_STFTFn) on the input
waveform and projecting the linear magnitude spectrum onto the mel filterbank.
The module's state dict layout matches the 'mel_stft.*' keys stored in the checkpoint
(mel_basis, stft_fn.forward_basis, stft_fn.inverse_basis).
"""
def __init__(
self,
filter_length: int,
hop_length: int,
win_length: int,
n_mel_channels: int,
sampling_rate: int,
mel_fmin: float,
mel_fmax: float,
):
super().__init__()
self.stft_fn = _STFTFn(filter_length, hop_length, win_length)
n_freqs = filter_length // 2 + 1
self.register_buffer("mel_basis", torch.zeros(n_mel_channels, n_freqs))
def mel_spectrogram(
self, y: torch.Tensor
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
"""Compute log-mel spectrogram and auxiliary spectral quantities.
Args:
y: Waveform tensor of shape (B, T).
Returns:
log_mel: Log-compressed mel spectrogram, shape (B, n_mel_channels, T_frames).
Computed as log(clamp(mel_basis @ magnitude, min=1e-5)).
magnitude: Linear amplitude spectrogram, shape (B, n_freqs, T_frames).
phase: Phase spectrogram in radians, shape (B, n_freqs, T_frames).
energy: Per-frame energy (L2 norm over frequency), shape (B, T_frames).
"""
magnitude, phase = self.stft_fn(y)
energy = torch.norm(magnitude, dim=1)
mel = torch.matmul(comfy.model_management.cast_to(self.mel_basis, dtype=magnitude.dtype, device=y.device), magnitude)
log_mel = torch.log(torch.clamp(mel, min=1e-5))
return log_mel, magnitude, phase, energy
class VocoderWithBWE(torch.nn.Module):
"""Vocoder with bandwidth extension (BWE) for higher sample rate output.
Chains a base vocoder (mel low-rate waveform) with a BWE stage that upsamples
to a higher rate. The BWE computes a mel spectrogram from the low-rate waveform.
"""
def __init__(self, config):
super().__init__()
vocoder_config = config["vocoder"]
bwe_config = config["bwe"]
self.vocoder = Vocoder(config=vocoder_config)
self.bwe_generator = Vocoder(
config={**bwe_config, "apply_final_activation": False}
)
self.input_sample_rate = bwe_config["input_sampling_rate"]
self.output_sample_rate = bwe_config["output_sampling_rate"]
self.hop_length = bwe_config["hop_length"]
self.mel_stft = MelSTFT(
filter_length=bwe_config["n_fft"],
hop_length=bwe_config["hop_length"],
win_length=bwe_config["n_fft"],
n_mel_channels=bwe_config["num_mels"],
sampling_rate=bwe_config["input_sampling_rate"],
mel_fmin=0.0,
mel_fmax=bwe_config["input_sampling_rate"] / 2.0,
)
self.resampler = UpSample1d(
ratio=bwe_config["output_sampling_rate"] // bwe_config["input_sampling_rate"],
persistent=False,
window_type="hann",
)
def _compute_mel(self, audio):
"""Compute log-mel spectrogram from waveform using causal STFT bases."""
B, C, T = audio.shape
flat = audio.reshape(B * C, -1) # (B*C, T)
mel, _, _, _ = self.mel_stft.mel_spectrogram(flat) # (B*C, n_mels, T_frames)
return mel.reshape(B, C, mel.shape[1], mel.shape[2]) # (B, C, n_mels, T_frames)
def forward(self, mel_spec):
x = self.vocoder(mel_spec)
_, _, T_low = x.shape
T_out = T_low * self.output_sample_rate // self.input_sample_rate
remainder = T_low % self.hop_length
if remainder != 0:
x = F.pad(x, (0, self.hop_length - remainder))
mel = self._compute_mel(x)
residual = self.bwe_generator(mel)
skip = self.resampler(x)
assert residual.shape == skip.shape, f"residual {residual.shape} != skip {skip.shape}"
return torch.clamp(residual + skip, -1, 1)[..., :T_out]

View File

@ -14,6 +14,7 @@ from comfy.ldm.flux.layers import EmbedND
from comfy.ldm.flux.math import apply_rope
import comfy.patcher_extension
import comfy.utils
from comfy.ldm.chroma_radiance.layers import NerfEmbedder
def invert_slices(slices, length):
@ -858,3 +859,267 @@ class NextDiT(nn.Module):
img = self.unpatchify(img, img_size, cap_size, return_tensor=x_is_tensor)[:, :, :h, :w]
return -img
#############################################################################
# Pixel Space Decoder Components #
#############################################################################
def _modulate_shift_scale(x, shift, scale):
return x * (1 + scale) + shift
class PixelResBlock(nn.Module):
"""
Residual block with AdaLN modulation, zero-initialised so it starts as
an identity at the beginning of training.
"""
def __init__(self, channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.in_ln = operations.LayerNorm(channels, eps=1e-6, dtype=dtype, device=device)
self.mlp = nn.Sequential(
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
nn.SiLU(),
operations.Linear(channels, channels, bias=True, dtype=dtype, device=device),
)
self.adaLN_modulation = nn.Sequential(
nn.SiLU(),
operations.Linear(channels, 3 * channels, bias=True, dtype=dtype, device=device),
)
def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
shift, scale, gate = self.adaLN_modulation(y).chunk(3, dim=-1)
h = _modulate_shift_scale(self.in_ln(x), shift, scale)
h = self.mlp(h)
return x + gate * h
class DCTFinalLayer(nn.Module):
"""Zero-initialised output projection (adopted from DiT)."""
def __init__(self, model_channels: int, out_channels: int, dtype=None, device=None, operations=None):
super().__init__()
self.norm_final = operations.LayerNorm(model_channels, elementwise_affine=False, eps=1e-6, dtype=dtype, device=device)
self.linear = operations.Linear(model_channels, out_channels, bias=True, dtype=dtype, device=device)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(self.norm_final(x))
class SimpleMLPAdaLN(nn.Module):
"""
Small MLP decoder head for the pixel-space variant.
Takes per-patch pixel values and a per-patch conditioning vector from the
transformer backbone and predicts the denoised pixel values.
x : [B*N, P^2, C] noisy pixel values per patch position
c : [B*N, dim] backbone hidden state per patch (conditioning)
[B*N, P^2, C]
"""
def __init__(
self,
in_channels: int,
model_channels: int,
out_channels: int,
z_channels: int,
num_res_blocks: int,
max_freqs: int = 8,
dtype=None,
device=None,
operations=None,
):
super().__init__()
self.dtype = dtype
# Project backbone hidden state → per-patch conditioning
self.cond_embed = operations.Linear(z_channels, model_channels, dtype=dtype, device=device)
# Input projection with DCT positional encoding
self.input_embedder = NerfEmbedder(
in_channels=in_channels,
hidden_size_input=model_channels,
max_freqs=max_freqs,
dtype=dtype,
device=device,
operations=operations,
)
# Residual blocks
self.res_blocks = nn.ModuleList([
PixelResBlock(model_channels, dtype=dtype, device=device, operations=operations) for _ in range(num_res_blocks)
])
# Output projection
self.final_layer = DCTFinalLayer(model_channels, out_channels, dtype=dtype, device=device, operations=operations)
def forward(self, x: torch.Tensor, c: torch.Tensor) -> torch.Tensor:
# x: [B*N, 1, P^2*C], c: [B*N, dim]
original_dtype = x.dtype
weight_dtype = self.cond_embed.weight.dtype if hasattr(self.cond_embed, "weight") and self.cond_embed.weight is not None else (self.dtype or x.dtype)
x = self.input_embedder(x) # [B*N, 1, model_channels]
y = self.cond_embed(c.to(weight_dtype)).unsqueeze(1) # [B*N, 1, model_channels]
x = x.to(weight_dtype)
for block in self.res_blocks:
x = block(x, y)
return self.final_layer(x).to(original_dtype) # [B*N, 1, P^2*C]
#############################################################################
# NextDiT Pixel Space #
#############################################################################
class NextDiTPixelSpace(NextDiT):
"""
Pixel-space variant of NextDiT.
Identical transformer backbone to NextDiT, but the output head is replaced
with a small MLP decoder (SimpleMLPAdaLN) that operates on raw pixel values
per patch rather than a single affine projection.
Key differences vs NextDiT:
``final_layer`` is removed; ``dec_net`` (SimpleMLPAdaLN) is used instead.
``_forward`` stores the raw patchified pixel values before the backbone
embedding and feeds them to ``dec_net`` together with the per-patch
backbone hidden states.
Supports optional x0 prediction via ``use_x0``.
"""
def __init__(
self,
# decoder-specific
decoder_hidden_size: int = 3840,
decoder_num_res_blocks: int = 4,
decoder_max_freqs: int = 8,
decoder_in_channels: int = None, # full flattened patch size (patch_size^2 * in_channels)
use_x0: bool = False,
# all NextDiT args forwarded unchanged
**kwargs,
):
super().__init__(**kwargs)
# Remove the latent-space final layer not used in pixel space
del self.final_layer
patch_size = kwargs.get("patch_size", 2)
in_channels = kwargs.get("in_channels", 4)
dim = kwargs.get("dim", 4096)
# decoder_in_channels is the full flattened patch: patch_size^2 * in_channels
dec_in_ch = decoder_in_channels if decoder_in_channels is not None else patch_size ** 2 * in_channels
self.dec_net = SimpleMLPAdaLN(
in_channels=dec_in_ch,
model_channels=decoder_hidden_size,
out_channels=dec_in_ch,
z_channels=dim,
num_res_blocks=decoder_num_res_blocks,
max_freqs=decoder_max_freqs,
dtype=kwargs.get("dtype"),
device=kwargs.get("device"),
operations=kwargs.get("operations"),
)
if use_x0:
self.register_buffer("__x0__", torch.tensor([]))
# ------------------------------------------------------------------
# Forward — mirrors NextDiT._forward exactly, replacing final_layer
# with the pixel-space dec_net decoder.
# ------------------------------------------------------------------
def _forward(self, x, timesteps, context, num_tokens, attention_mask=None, ref_latents=[], ref_contexts=[], siglip_feats=[], transformer_options={}, **kwargs):
omni = len(ref_latents) > 0
if omni:
timesteps = torch.cat([timesteps * 0, timesteps], dim=0)
t = 1.0 - timesteps
cap_feats = context
cap_mask = attention_mask
bs, c, h, w = x.shape
x = comfy.ldm.common_dit.pad_to_patch_size(x, (self.patch_size, self.patch_size))
t = self.t_embedder(t * self.time_scale, dtype=x.dtype)
adaln_input = t
if self.clip_text_pooled_proj is not None:
pooled = kwargs.get("clip_text_pooled", None)
if pooled is not None:
pooled = self.clip_text_pooled_proj(pooled)
else:
pooled = torch.zeros((x.shape[0], self.clip_text_dim), device=x.device, dtype=x.dtype)
adaln_input = self.time_text_embed(torch.cat((t, pooled), dim=-1))
# ---- capture raw pixel patches before patchify_and_embed embeds them ----
pH = pW = self.patch_size
B, C, H, W = x.shape
pixel_patches = (
x.view(B, C, H // pH, pH, W // pW, pW)
.permute(0, 2, 4, 3, 5, 1) # [B, Ht, Wt, pH, pW, C]
.flatten(3) # [B, Ht, Wt, pH*pW*C]
.flatten(1, 2) # [B, N, pH*pW*C]
)
N = pixel_patches.shape[1]
# decoder sees one token per patch: [B*N, 1, P^2*C]
pixel_values = pixel_patches.reshape(B * N, 1, pH * pW * C)
patches = transformer_options.get("patches", {})
x_is_tensor = isinstance(x, torch.Tensor)
img, mask, img_size, cap_size, freqs_cis, timestep_zero_index = self.patchify_and_embed(
x, cap_feats, cap_mask, adaln_input, num_tokens,
ref_latents=ref_latents, ref_contexts=ref_contexts,
siglip_feats=siglip_feats, transformer_options=transformer_options
)
freqs_cis = freqs_cis.to(img.device)
transformer_options["total_blocks"] = len(self.layers)
transformer_options["block_type"] = "double"
img_input = img
for i, layer in enumerate(self.layers):
transformer_options["block_index"] = i
img = layer(img, mask, freqs_cis, adaln_input, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options)
if "double_block" in patches:
for p in patches["double_block"]:
out = p({"img": img[:, cap_size[0]:], "img_input": img_input[:, cap_size[0]:], "txt": img[:, :cap_size[0]], "pe": freqs_cis[:, cap_size[0]:], "vec": adaln_input, "x": x, "block_index": i, "transformer_options": transformer_options})
if "img" in out:
img[:, cap_size[0]:] = out["img"]
if "txt" in out:
img[:, :cap_size[0]] = out["txt"]
# ---- pixel-space decoder (replaces final_layer + unpatchify) ----
# img may have padding tokens beyond N; only the first N are real image patches
img_hidden = img[:, cap_size[0]:cap_size[0] + N, :] # [B, N, dim]
decoder_cond = img_hidden.reshape(B * N, self.dim) # [B*N, dim]
output = self.dec_net(pixel_values, decoder_cond) # [B*N, 1, P^2*C]
output = output.reshape(B, N, -1) # [B, N, P^2*C]
# prepend zero cap placeholder so unpatchify indexing works unchanged
cap_placeholder = torch.zeros(
B, cap_size[0], output.shape[-1], device=output.device, dtype=output.dtype
)
img_out = self.unpatchify(
torch.cat([cap_placeholder, output], dim=1),
img_size, cap_size, return_tensor=x_is_tensor
)[:, :, :h, :w]
return -img_out
def forward(self, x, timesteps, context, num_tokens, attention_mask=None, **kwargs):
# _forward returns neg_x0 = -x0 (negated decoder output).
#
# Reference inference (working_inference_reference.py):
# out = _forward(img, t) # = -x0
# pred = (img - out) / t # = (img + x0) / t [_apply_x0_residual]
# img += (t_prev - t_curr) * pred # Euler step
#
# ComfyUI's Euler sampler does the same:
# x_next = x + (sigma_next - sigma) * model_output
# So model_output must equal pred = (x - neg_x0) / t = (x - (-x0)) / t = (x + x0) / t
neg_x0 = comfy.patcher_extension.WrapperExecutor.new_class_executor(
self._forward,
self,
comfy.patcher_extension.get_all_wrappers(comfy.patcher_extension.WrappersMP.DIFFUSION_MODEL, kwargs.get("transformer_options", {}))
).execute(x, timesteps, context, num_tokens, attention_mask, **kwargs)
return (x - neg_x0) / timesteps.view(-1, 1, 1, 1)

View File

@ -372,7 +372,8 @@ def attention_split(q, k, v, heads, mask=None, attn_precision=None, skip_reshape
r1[:, i:end] = einsum('b i j, b j d -> b i d', s2, v)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
if first_op_done == False:
model_management.soft_empty_cache(True)
if cleared_cache == False:

View File

@ -258,7 +258,8 @@ def slice_attention(q, k, v):
r1[:, :, i:end] = torch.bmm(v, s2)
del s2
break
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
model_management.soft_empty_cache(True)
steps *= 2
if steps > 128:
@ -314,7 +315,8 @@ def pytorch_attention(q, k, v):
try:
out = comfy.ops.scaled_dot_product_attention(q, k, v, attn_mask=None, dropout_p=0.0, is_causal=False)
out = out.transpose(2, 3).reshape(orig_shape)
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("scaled_dot_product_attention OOMed: switched to slice attention")
oom_fallback = True
if oom_fallback:

View File

@ -169,7 +169,8 @@ def _get_attention_scores_no_kv_chunking(
try:
attn_probs = attn_scores.softmax(dim=-1)
del attn_scores
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("ran out of memory while running softmax in _get_attention_scores_no_kv_chunking, trying slower in place softmax instead")
attn_scores -= attn_scores.max(dim=-1, keepdim=True).values # noqa: F821 attn_scores is not defined
torch.exp(attn_scores, out=attn_scores)

View File

@ -149,6 +149,9 @@ class Attention(nn.Module):
seq_img = hidden_states.shape[1]
seq_txt = encoder_hidden_states.shape[1]
transformer_patches = transformer_options.get("patches", {})
extra_options = transformer_options.copy()
# Project and reshape to BHND format (batch, heads, seq, dim)
img_query = self.to_q(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
img_key = self.to_k(hidden_states).view(batch_size, seq_img, self.heads, -1).transpose(1, 2).contiguous()
@ -167,15 +170,22 @@ class Attention(nn.Module):
joint_key = torch.cat([txt_key, img_key], dim=2)
joint_value = torch.cat([txt_value, img_value], dim=2)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
if encoder_hidden_states_mask is not None:
attn_mask = torch.zeros((batch_size, 1, seq_txt + seq_img), dtype=hidden_states.dtype, device=hidden_states.device)
attn_mask[:, 0, :seq_txt] = encoder_hidden_states_mask
else:
attn_mask = None
extra_options["img_slice"] = [txt_query.shape[2], joint_query.shape[2]]
if "attn1_patch" in transformer_patches:
patch = transformer_patches["attn1_patch"]
for p in patch:
out = p(joint_query, joint_key, joint_value, pe=image_rotary_emb, attn_mask=encoder_hidden_states_mask, extra_options=extra_options)
joint_query, joint_key, joint_value, image_rotary_emb, encoder_hidden_states_mask = out.get("q", joint_query), out.get("k", joint_key), out.get("v", joint_value), out.get("pe", image_rotary_emb), out.get("attn_mask", encoder_hidden_states_mask)
joint_query = apply_rope1(joint_query, image_rotary_emb)
joint_key = apply_rope1(joint_key, image_rotary_emb)
joint_hidden_states = optimized_attention_masked(joint_query, joint_key, joint_value, self.heads,
attn_mask, transformer_options=transformer_options,
skip_reshape=True)
@ -444,6 +454,7 @@ class QwenImageTransformer2DModel(nn.Module):
timestep_zero_index = None
if ref_latents is not None:
ref_num_tokens = []
h = 0
w = 0
index = 0
@ -474,16 +485,16 @@ class QwenImageTransformer2DModel(nn.Module):
kontext, kontext_ids, _ = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset)
hidden_states = torch.cat([hidden_states, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = num_embeds
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_start = round(max(((x.shape[-1] + (self.patch_size // 2)) // self.patch_size) // 2, ((x.shape[-2] + (self.patch_size // 2)) // self.patch_size) // 2))
txt_ids = torch.arange(txt_start, txt_start + context.shape[1], device=x.device).reshape(1, -1, 1).repeat(x.shape[0], 1, 3)
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
hidden_states = self.img_in(hidden_states)
encoder_hidden_states = self.txt_norm(encoder_hidden_states)
@ -495,6 +506,18 @@ class QwenImageTransformer2DModel(nn.Module):
patches = transformer_options.get("patches", {})
blocks_replace = patches_replace.get("dit", {})
if "post_input" in patches:
for p in patches["post_input"]:
out = p({"img": hidden_states, "txt": encoder_hidden_states, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
hidden_states = out["img"]
encoder_hidden_states = out["txt"]
img_ids = out["img_ids"]
txt_ids = out["txt_ids"]
ids = torch.cat((txt_ids, img_ids), dim=1)
image_rotary_emb = self.pe_embedder(ids).to(x.dtype).contiguous()
del ids, txt_ids, img_ids
transformer_options["total_blocks"] = len(self.transformer_blocks)
transformer_options["block_type"] = "double"
for i, block in enumerate(self.transformer_blocks):

View File

@ -99,6 +99,9 @@ def model_lora_keys_clip(model, key_map={}):
for k in sdk:
if k.endswith(".weight"):
key_map["text_encoders.{}".format(k[:-len(".weight")])] = k #generic lora format without any weird key names
tp = k.find(".transformer.") #also map without wrapper prefix for composite text encoder models
if tp > 0 and not k.startswith("clip_"):
key_map["text_encoders.{}".format(k[tp + 1:-len(".weight")])] = k
text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
clip_l_present = False

View File

@ -1034,7 +1034,7 @@ class LTXAV(BaseModel):
cross_attn = kwargs.get("cross_attn", None)
if cross_attn is not None:
if hasattr(self.diffusion_model, "preprocess_text_embeds"):
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()))
cross_attn = self.diffusion_model.preprocess_text_embeds(cross_attn.to(device=device, dtype=self.get_dtype_inference()), unprocessed=kwargs.get("unprocessed_ltxav_embeds", False))
out['c_crossattn'] = comfy.conds.CONDRegular(cross_attn)
out['frame_rate'] = comfy.conds.CONDConstant(kwargs.get("frame_rate", 25))
@ -1276,6 +1276,11 @@ class Lumina2(BaseModel):
out['ref_latents'] = list([1, 16, sum(map(lambda a: math.prod(a.size()[2:]), ref_latents))])
return out
class ZImagePixelSpace(Lumina2):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
BaseModel.__init__(self, model_config, model_type, device=device, unet_model=comfy.ldm.lumina.model.NextDiTPixelSpace)
self.memory_usage_factor_conds = ("ref_latents",)
class WAN21(BaseModel):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@ -1,4 +1,5 @@
import json
import comfy.memory_management
import comfy.supported_models
import comfy.supported_models_base
import comfy.utils
@ -464,6 +465,29 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
if sig_weight is not None:
dit_config["siglip_feat_dim"] = sig_weight.shape[0]
dec_cond_key = '{}dec_net.cond_embed.weight'.format(key_prefix)
if dec_cond_key in state_dict_keys: # pixel-space variant
dit_config["image_model"] = "zimage_pixel"
# patch_size and in_channels are derived from x_embedder:
# x_embedder: Linear(patch_size * patch_size * in_channels, dim)
# The decoder also receives the full flat patch, so decoder_in_channels = x_embedder input dim.
x_emb_in = state_dict['{}x_embedder.weight'.format(key_prefix)].shape[1]
dec_out = state_dict['{}dec_net.final_layer.linear.weight'.format(key_prefix)].shape[0]
# patch_size: infer from decoder final layer output matching x_embedder input
# in_channels: infer from dec_net input_embedder (in_features = dec_in_ch + max_freqs^2)
embedder_w = state_dict['{}dec_net.input_embedder.embedder.0.weight'.format(key_prefix)]
dec_in_ch = dec_out # decoder in == decoder out (same pixel space)
dit_config["patch_size"] = round((x_emb_in / 3) ** 0.5) # assume RGB (in_channels=3)
dit_config["in_channels"] = 3
dit_config["decoder_in_channels"] = dec_in_ch
dit_config["decoder_hidden_size"] = state_dict[dec_cond_key].shape[0]
dit_config["decoder_num_res_blocks"] = count_blocks(
state_dict_keys, '{}dec_net.res_blocks.'.format(key_prefix) + '{}.'
)
dit_config["decoder_max_freqs"] = int((embedder_w.shape[1] - dec_in_ch) ** 0.5)
if '{}__x0__'.format(key_prefix) in state_dict_keys:
dit_config["use_x0"] = True
return dit_config
if '{}head.modulation'.format(key_prefix) in state_dict_keys: # Wan 2.1
@ -1095,8 +1119,13 @@ def convert_diffusers_mmdit(state_dict, output_prefix=""):
new[:old_weight.shape[0]] = old_weight
old_weight = new
if old_weight is out_sd.get(t[0], None) and comfy.memory_management.aimdo_enabled:
old_weight = old_weight.clone()
w = old_weight.narrow(offset[0], offset[1], offset[2])
else:
if comfy.memory_management.aimdo_enabled:
weight = weight.clone()
old_weight = weight
w = weight
w[:] = fun(weight)

View File

@ -32,9 +32,6 @@ import comfy.memory_management
import comfy.utils
import comfy.quant_ops
import comfy_aimdo.torch
import comfy_aimdo.model_vbar
class VRAMState(Enum):
DISABLED = 0 #No vram present: no need to move models to vram
NO_VRAM = 1 #Very low vram: enable all the options to save vram
@ -273,6 +270,23 @@ try:
except:
OOM_EXCEPTION = Exception
try:
ACCELERATOR_ERROR = torch.AcceleratorError
except AttributeError:
ACCELERATOR_ERROR = RuntimeError
def is_oom(e):
if isinstance(e, OOM_EXCEPTION):
return True
if isinstance(e, ACCELERATOR_ERROR) and (getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()):
discard_cuda_async_error()
return True
return False
def raise_non_oom(e):
if not is_oom(e):
raise e
XFORMERS_VERSION = ""
XFORMERS_ENABLED_VAE = True
if args.disable_xformers:
@ -867,6 +881,8 @@ def archive_model_dtypes(model):
for name, module in model.named_modules():
for param_name, param in module.named_parameters(recurse=False):
setattr(module, f"{param_name}_comfy_model_dtype", param.dtype)
for buf_name, buf in module.named_buffers(recurse=False):
setattr(module, f"{buf_name}_comfy_model_dtype", buf.dtype)
def cleanup_models():
@ -908,11 +924,14 @@ def unet_offload_device():
return torch.device("cpu")
def unet_inital_load_device(parameters, dtype):
cpu_dev = torch.device("cpu")
if comfy.memory_management.aimdo_enabled:
return cpu_dev
torch_dev = get_torch_device()
if vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.SHARED:
return torch_dev
cpu_dev = torch.device("cpu")
if DISABLE_SMART_MEMORY or vram_state == VRAMState.NO_VRAM:
return cpu_dev
@ -920,7 +939,7 @@ def unet_inital_load_device(parameters, dtype):
mem_dev = get_free_memory(torch_dev)
mem_cpu = get_free_memory(cpu_dev)
if mem_dev > mem_cpu and model_size < mem_dev and comfy.memory_management.aimdo_enabled:
if mem_dev > mem_cpu and model_size < mem_dev:
return torch_dev
else:
return cpu_dev
@ -1014,7 +1033,7 @@ def text_encoder_offload_device():
def text_encoder_device():
if args.gpu_only:
return get_torch_device()
elif vram_state == VRAMState.HIGH_VRAM or vram_state == VRAMState.NORMAL_VRAM:
elif vram_state in (VRAMState.HIGH_VRAM, VRAMState.NORMAL_VRAM) or comfy.memory_management.aimdo_enabled:
if should_use_fp16(prioritize_performance=False):
return get_torch_device()
else:
@ -1023,6 +1042,9 @@ def text_encoder_device():
return torch.device("cpu")
def text_encoder_initial_device(load_device, offload_device, model_size=0):
if comfy.memory_management.aimdo_enabled:
return offload_device
if load_device == offload_device or model_size <= 1024 * 1024 * 1024:
return offload_device
@ -1220,6 +1242,7 @@ def reset_cast_buffers():
LARGEST_CASTED_WEIGHT = (None, 0)
for offload_stream in STREAM_CAST_BUFFERS:
offload_stream.synchronize()
synchronize()
STREAM_CAST_BUFFERS.clear()
soft_empty_cache()
@ -1283,43 +1306,6 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
def cast_to(weight, dtype=None, device=None, non_blocking=False, copy=False, stream=None, r=None):
if hasattr(weight, "_v"):
#Unexpected usage patterns. There is no reason these don't work but they
#have no testing and no callers do this.
assert r is None
assert stream is None
cast_geometry = comfy.memory_management.tensors_to_geometries([ weight ])
if dtype is None:
dtype = weight._model_dtype
signature = comfy_aimdo.model_vbar.vbar_fault(weight._v)
if signature is not None:
if comfy_aimdo.model_vbar.vbar_signature_compare(signature, weight._v_signature):
v_tensor = weight._v_tensor
else:
raw_tensor = comfy_aimdo.torch.aimdo_to_tensor(weight._v, device)
v_tensor = comfy.memory_management.interpret_gathered_like(cast_geometry, raw_tensor)[0]
weight._v_tensor = v_tensor
weight._v_signature = signature
#Send it over
v_tensor.copy_(weight, non_blocking=non_blocking)
return v_tensor.to(dtype=dtype)
r = torch.empty_like(weight, dtype=dtype, device=device)
if weight.dtype != r.dtype and weight.dtype != weight._model_dtype:
#Offloaded casting could skip this, however it would make the quantizations
#inconsistent between loaded and offloaded weights. So force the double casting
#that would happen in regular flow to make offload deterministic.
cast_buffer = torch.empty_like(weight, dtype=weight._model_dtype, device=device)
cast_buffer.copy_(weight, non_blocking=non_blocking)
weight = cast_buffer
r.copy_(weight, non_blocking=non_blocking)
return r
if device is None or weight.device == device:
if not copy:
if dtype is None or weight.dtype == dtype:
@ -1371,7 +1357,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b
synchronize()
except torch.AcceleratorError:
except RuntimeError:
#Dump it! We already know about it from the synchronous return
pass
@ -1775,12 +1761,16 @@ def lora_compute_dtype(device):
return dtype
def synchronize():
if cpu_mode():
return
if is_intel_xpu():
torch.xpu.synchronize()
elif torch.cuda.is_available():
torch.cuda.synchronize()
def soft_empty_cache(force=False):
if cpu_mode():
return
global cpu_state
if cpu_state == CPUState.MPS:
torch.mps.empty_cache()

View File

@ -241,6 +241,7 @@ class ModelPatcher:
self.patches = {}
self.backup = {}
self.backup_buffers = {}
self.object_patches = {}
self.object_patches_backup = {}
self.weight_wrapper_patches = {}
@ -306,10 +307,16 @@ class ModelPatcher:
return self.model.lowvram_patch_counter
def get_free_memory(self, device):
return comfy.model_management.get_free_memory(device)
#Prioritize batching (incl. CFG/conds etc) over keeping the model resident. In
#the vast majority of setups a little bit of offloading on the giant model more
#than pays for CFG. So return everything both torch and Aimdo could give us
aimdo_mem = 0
if comfy.memory_management.aimdo_enabled:
aimdo_mem = comfy_aimdo.model_vbar.vbars_analyze()
return comfy.model_management.get_free_memory(device) + aimdo_mem
def get_clone_model_override(self):
return self.model, (self.backup, self.object_patches_backup, self.pinned)
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
def clone(self, disable_dynamic=False, model_override=None):
class_ = self.__class__
@ -336,7 +343,7 @@ class ModelPatcher:
n.force_cast_weights = self.force_cast_weights
n.backup, n.object_patches_backup, n.pinned = model_override[1]
n.backup, n.backup_buffers, n.object_patches_backup, n.pinned = model_override[1]
# attachments
n.attachments = {}
@ -592,6 +599,27 @@ class ModelPatcher:
return models
def model_patches_call_function(self, function_name="cleanup", arguments={}):
to = self.model_options["transformer_options"]
if "patches" in to:
patches = to["patches"]
for name in patches:
patch_list = patches[name]
for i in range(len(patch_list)):
if hasattr(patch_list[i], function_name):
getattr(patch_list[i], function_name)(**arguments)
if "patches_replace" in to:
patches = to["patches_replace"]
for name in patches:
patch_list = patches[name]
for k in patch_list:
if hasattr(patch_list[k], function_name):
getattr(patch_list[k], function_name)(**arguments)
if "model_function_wrapper" in self.model_options:
wrap_func = self.model_options["model_function_wrapper"]
if hasattr(wrap_func, function_name):
getattr(wrap_func, function_name)(**arguments)
def model_dtype(self):
if hasattr(self.model, "get_dtype"):
return self.model.get_dtype()
@ -698,7 +726,7 @@ class ModelPatcher:
for key in list(self.pinned):
self.unpin_weight(key)
def _load_list(self, prio_comfy_cast_weights=False, default_device=None):
def _load_list(self, for_dynamic=False, default_device=None):
loading = []
for n, m in self.model.named_modules():
default = False
@ -708,8 +736,8 @@ class ModelPatcher:
default = True # default random weights in non leaf modules
break
if default and default_device is not None:
for param in params.values():
param.data = param.data.to(device=default_device)
for param_name, param in params.items():
param.data = param.data.to(device=default_device, dtype=getattr(m, param_name + "_comfy_model_dtype", None))
if not default and (hasattr(m, "comfy_cast_weights") or len(params) > 0):
module_mem = comfy.model_management.module_size(m)
module_offload_mem = module_mem
@ -726,8 +754,13 @@ class ModelPatcher:
return 0
module_offload_mem += check_module_offload_mem("{}.weight".format(n))
module_offload_mem += check_module_offload_mem("{}.bias".format(n))
prepend = (not hasattr(m, "comfy_cast_weights"),) if prio_comfy_cast_weights else ()
loading.append(prepend + (module_offload_mem, module_mem, n, m, params))
# Dynamic: small weights (<64KB) first, then larger weights prioritized by size.
# Non-dynamic: prioritize by module offload cost.
if for_dynamic:
sort_criteria = (module_offload_mem >= 64 * 1024, -module_offload_mem)
else:
sort_criteria = (module_offload_mem,)
loading.append(sort_criteria + (module_mem, n, m, params))
return loading
def load(self, device_to=None, lowvram_model_memory=0, force_patch_weights=False, full_load=False):
@ -1050,6 +1083,7 @@ class ModelPatcher:
return comfy.lora.calculate_weight(patches, weight, key, intermediate_dtype=intermediate_dtype)
def cleanup(self):
self.model_patches_call_function(function_name="cleanup")
self.clean_hooks()
if hasattr(self.model, "current_patcher"):
self.model.current_patcher = None
@ -1435,10 +1469,6 @@ class ModelPatcherDynamic(ModelPatcher):
def __init__(self, model, load_device, offload_device, size=0, weight_inplace_update=False):
super().__init__(model, load_device, offload_device, size, weight_inplace_update)
#this is now way more dynamic and we dont support the same base model for both Dynamic
#and non-dynamic patchers.
if hasattr(self.model, "model_loaded_weight_memory"):
del self.model.model_loaded_weight_memory
if not hasattr(self.model, "dynamic_vbars"):
self.model.dynamic_vbars = {}
self.non_dynamic_delegate_model = None
@ -1461,15 +1491,7 @@ class ModelPatcherDynamic(ModelPatcher):
def loaded_size(self):
vbar = self._vbar_get()
if vbar is None:
return 0
return vbar.loaded_size()
def get_free_memory(self, device):
#NOTE: on high condition / batch counts, estimate should have already vacated
#all non-dynamic models so this is safe even if its not 100% true that this
#would all be avaiable for inference use.
return comfy.model_management.get_total_memory(device) - self.model_size()
return (vbar.loaded_size() if vbar is not None else 0) + self.model.model_loaded_weight_memory
#Pinning is deferred to ops time. Assert against this API to avoid pin leaks.
@ -1504,6 +1526,7 @@ class ModelPatcherDynamic(ModelPatcher):
num_patches = 0
allocated_size = 0
self.model.model_loaded_weight_memory = 0
with self.use_ejected():
self.unpatch_hooks()
@ -1512,15 +1535,11 @@ class ModelPatcherDynamic(ModelPatcher):
if vbar is not None:
vbar.prioritize()
#We force reserve VRAM for the non comfy-weight so we dont have to deal
#with pin and unpin syncrhonization which can be expensive for small weights
#with a high layer rate (e.g. autoregressive LLMs).
#prioritize the non-comfy weights (note the order reverse).
loading = self._load_list(prio_comfy_cast_weights=True, default_device=device_to)
loading.sort(reverse=True)
loading = self._load_list(for_dynamic=True, default_device=device_to)
loading.sort()
for x in loading:
_, _, _, n, m, params = x
*_, module_mem, n, m, params = x
def set_dirty(item, dirty):
if dirty or not hasattr(item, "_v_signature"):
@ -1558,6 +1577,9 @@ class ModelPatcherDynamic(ModelPatcher):
if key in self.backup:
comfy.utils.set_attr_param(self.model, key, self.backup[key].weight)
self.patch_weight_to_device(key, device_to=device_to)
weight, _, _ = get_key_weight(self.model, key)
if weight is not None:
self.model.model_loaded_weight_memory += weight.numel() * weight.element_size()
if hasattr(m, "comfy_cast_weights"):
m.comfy_cast_weights = True
@ -1583,21 +1605,26 @@ class ModelPatcherDynamic(ModelPatcher):
for param in params:
key = key_param_name_to_key(n, param)
weight, _, _ = get_key_weight(self.model, key)
weight.seed_key = key
set_dirty(weight, dirty)
geometry = weight
model_dtype = getattr(m, param + "_comfy_model_dtype", None) or weight.dtype
geometry = comfy.memory_management.TensorGeometry(shape=weight.shape, dtype=model_dtype)
weight_size = geometry.numel() * geometry.element_size()
if vbar is not None and not hasattr(weight, "_v"):
weight._v = vbar.alloc(weight_size)
weight._model_dtype = model_dtype
allocated_size += weight_size
vbar.set_watermark_limit(allocated_size)
if key not in self.backup:
self.backup[key] = collections.namedtuple('Dimension', ['weight', 'inplace_update'])(weight, False)
model_dtype = getattr(m, param + "_comfy_model_dtype", None)
casted_weight = weight.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_param(self.model, key, casted_weight)
self.model.model_loaded_weight_memory += casted_weight.numel() * casted_weight.element_size()
move_weight_functions(m, device_to)
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.")
for key, buf in self.model.named_buffers(recurse=True):
if key not in self.backup_buffers:
self.backup_buffers[key] = buf
module, buf_name = comfy.utils.resolve_attr(self.model, key)
model_dtype = getattr(module, buf_name + "_comfy_model_dtype", None)
casted_buf = buf.to(dtype=model_dtype, device=device_to)
comfy.utils.set_attr_buffer(self.model, key, casted_buf)
self.model.model_loaded_weight_memory += casted_buf.numel() * casted_buf.element_size()
force_load_stat = f" Force pre-loaded {len(self.backup)} weights: {self.model.model_loaded_weight_memory // 1024} KB." if len(self.backup) > 0 else ""
logging.info(f"Model {self.model.__class__.__name__} prepared for dynamic VRAM loading. {allocated_size // (1024 ** 2)}MB Staged. {num_patches} patches attached.{force_load_stat}")
self.model.device = device_to
self.model.current_weight_patches_uuid = self.patches_uuid
@ -1613,12 +1640,23 @@ class ModelPatcherDynamic(ModelPatcher):
assert self.load_device != torch.device("cpu")
vbar = self._vbar_get()
return 0 if vbar is None else vbar.free_memory(memory_to_free)
freed = 0 if vbar is None else vbar.free_memory(memory_to_free)
if freed < memory_to_free:
for key in list(self.backup.keys()):
bk = self.backup.pop(key)
comfy.utils.set_attr_param(self.model, key, bk.weight)
for key in list(self.backup_buffers.keys()):
comfy.utils.set_attr_buffer(self.model, key, self.backup_buffers.pop(key))
freed += self.model.model_loaded_weight_memory
self.model.model_loaded_weight_memory = 0
return freed
def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(prio_comfy_cast_weights=True, default_device=self.offload_device)
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading:
_, _, _, _, m, _ = x
*_, m, _ = x
ram_to_unload -= comfy.pinned_memory.unpin_memory(m)
if ram_to_unload <= 0:
return
@ -1640,11 +1678,6 @@ class ModelPatcherDynamic(ModelPatcher):
for m in self.model.modules():
move_weight_functions(m, device_to)
keys = list(self.backup.keys())
for k in keys:
bk = self.backup[k]
comfy.utils.set_attr_param(self.model, k, bk.weight)
def partially_load(self, device_to, extra_memory=0, force_patch_weights=False):
assert not force_patch_weights #See above
with self.use_ejected(skip_and_inject_on_exit_only=True):

View File

@ -80,6 +80,21 @@ def cast_to_input(weight, input, non_blocking=False, copy=True):
def cast_bias_weight_with_vbar(s, dtype, device, bias_dtype, non_blocking, compute_dtype, want_requant):
#vbar doesn't support CPU weights, but some custom nodes have weird paths
#that might switch the layer to the CPU and expect it to work. We have to take
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
#If you are a custom node author reading this, please move your layer to the GPU
#or declare your ModelPatcher as CPU in the first place.
if comfy.model_management.is_device_cpu(device):
weight = s.weight.to(dtype=dtype, copy=True)
if isinstance(weight, QuantizedTensor):
weight = weight.dequantize()
bias = None
if s.bias is not None:
bias = s.bias.to(dtype=bias_dtype, copy=True)
return weight, bias, (None, None, None)
offload_stream = None
xfer_dest = None
@ -269,8 +284,8 @@ def uncast_bias_weight(s, weight, bias, offload_stream):
return
os, weight_a, bias_a = offload_stream
device=None
#FIXME: This is not good RTTI
if not isinstance(weight_a, torch.Tensor):
#FIXME: This is really bad RTTI
if weight_a is not None and not isinstance(weight_a, torch.Tensor):
comfy_aimdo.model_vbar.vbar_unpin(s._v)
device = weight_a
if os is None:
@ -660,23 +675,29 @@ class fp8_ops(manual_cast):
CUBLAS_IS_AVAILABLE = False
try:
from cublas_ops import CublasLinear
from cublas_ops import CublasLinear, cublas_half_matmul
CUBLAS_IS_AVAILABLE = True
except ImportError:
pass
if CUBLAS_IS_AVAILABLE:
class cublas_ops(disable_weight_init):
class Linear(CublasLinear, disable_weight_init.Linear):
class cublas_ops(manual_cast):
class Linear(CublasLinear, manual_cast.Linear):
def reset_parameters(self):
return None
def forward_comfy_cast_weights(self, input):
return super().forward(input)
weight, bias, offload_stream = cast_bias_weight(self, input, offloadable=True)
x = cublas_half_matmul(input, weight, bias, self._epilogue_str, self.has_bias)
uncast_bias_weight(self, weight, bias, offload_stream)
return x
def forward(self, *args, **kwargs):
return super().forward(*args, **kwargs)
run_every_op()
if self.comfy_cast_weights or len(self.weight_function) > 0 or len(self.bias_function) > 0:
return self.forward_comfy_cast_weights(*args, **kwargs)
else:
return super().forward(*args, **kwargs)
# ==============================================================================
# Mixed Precision Operations

View File

@ -428,7 +428,7 @@ class CLIP:
def generate(self, tokens, do_sample=True, max_length=256, temperature=1.0, top_k=50, top_p=0.95, min_p=0.0, repetition_penalty=1.0, seed=None):
self.cond_stage_model.reset_clip_options()
self.load_model()
self.load_model(tokens)
self.cond_stage_model.set_clip_options({"layer": None})
self.cond_stage_model.set_clip_options({"execution_device": self.patcher.load_device})
return self.cond_stage_model.generate(tokens, do_sample=do_sample, max_length=max_length, temperature=temperature, top_k=top_k, top_p=top_p, min_p=min_p, repetition_penalty=repetition_penalty, seed=seed)
@ -954,7 +954,8 @@ class VAE:
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
pixel_samples[x:x+batch_number] = out
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
@ -1029,7 +1030,8 @@ class VAE:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
samples[x:x + batch_number] = out
except model_management.OOM_EXCEPTION:
except Exception as e:
model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE encoding, retrying with tiled VAE encoding.")
#NOTE: We don't know what tensors were allocated to stack variables at the time of the
#exception and the exception itself refs them all until we get out of this except block.
@ -1467,7 +1469,7 @@ def load_text_encoder_state_dicts(state_dicts=[], embedding_directory=None, clip
clip_target.clip = comfy.text_encoders.kandinsky5.te(**llama_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.kandinsky5.Kandinsky5TokenizerImage
elif clip_type == CLIPType.LTXV:
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data))
clip_target.clip = comfy.text_encoders.lt.ltxav_te(**llama_detect(clip_data), **comfy.text_encoders.lt.sd_detect(clip_data))
clip_target.tokenizer = comfy.text_encoders.lt.LTXAVGemmaTokenizer
tokenizer_data["spiece_model"] = clip_data[0].get("spiece_model", None)
elif clip_type == CLIPType.NEWBIE:

View File

@ -1118,6 +1118,20 @@ class ZImage(Lumina2):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen3_4b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.z_image.ZImageTokenizer, comfy.text_encoders.z_image.te(**hunyuan_detect))
class ZImagePixelSpace(ZImage):
unet_config = {
"image_model": "zimage_pixel",
}
# Pixel-space model: no spatial compression, operates on raw RGB patches.
latent_format = latent_formats.ZImagePixelSpace
# Much lower memory than latent-space models (no VAE, small patches).
memory_usage_factor = 0.03 # TODO: figure out the optimal value for this.
def get_model(self, state_dict, prefix="", device=None):
return model_base.ZImagePixelSpace(self, device=device)
class WAN21_T2V(supported_models_base.BASE):
unet_config = {
"image_model": "wan2.1",
@ -1720,6 +1734,6 @@ class LongCatImage(supported_models_base.BASE):
hunyuan_detect = comfy.text_encoders.hunyuan_video.llama_detect(state_dict, "{}qwen25_7b.transformer.".format(pref))
return supported_models_base.ClipTarget(comfy.text_encoders.longcat_image.LongCatImageTokenizer, comfy.text_encoders.longcat_image.te(**hunyuan_detect))
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models = [LotusD, Stable_Zero123, SD15_instructpix2pix, SD15, SD20, SD21UnclipL, SD21UnclipH, SDXL_instructpix2pix, SDXLRefiner, SDXL, SSD1B, KOALA_700M, KOALA_1B, Segmind_Vega, SD_X4Upscaler, Stable_Cascade_C, Stable_Cascade_B, SV3D_u, SV3D_p, SD3, StableAudio, AuraFlow, PixArtAlpha, PixArtSigma, HunyuanDiT, HunyuanDiT1, FluxInpaint, Flux, LongCatImage, FluxSchnell, GenmoMochi, LTXV, LTXAV, HunyuanVideo15_SR_Distilled, HunyuanVideo15, HunyuanImage21Refiner, HunyuanImage21, HunyuanVideoSkyreelsI2V, HunyuanVideoI2V, HunyuanVideo, CosmosT2V, CosmosI2V, CosmosT2IPredict2, CosmosI2VPredict2, ZImagePixelSpace, ZImage, Lumina2, WAN22_T2V, WAN21_T2V, WAN21_I2V, WAN21_FunControl2V, WAN21_Vace, WAN21_Camera, WAN22_Camera, WAN22_S2V, WAN21_HuMo, WAN22_Animate, WAN21_FlowRVS, WAN21_SCAIL, Hunyuan3Dv2mini, Hunyuan3Dv2, Hunyuan3Dv2_1, HiDream, Chroma, ChromaRadiance, ACEStep, ACEStep15, Omnigen2, QwenImage, Flux2, Kandinsky5Image, Kandinsky5, Anima]
models += [SVD_img2vid]

View File

@ -97,18 +97,39 @@ class Gemma3_12BModel(sd1_clip.SDClipModel):
comfy.utils.normalize_image_embeddings(embeds, embeds_info, self.transformer.model.config.hidden_size ** 0.5)
return self.transformer.generate(embeds, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed, stop_tokens=[106]) # 106 is <end_of_turn>
class DualLinearProjection(torch.nn.Module):
def __init__(self, in_dim, out_dim_video, out_dim_audio, dtype=None, device=None, operations=None):
super().__init__()
self.audio_aggregate_embed = operations.Linear(in_dim, out_dim_audio, bias=True, dtype=dtype, device=device)
self.video_aggregate_embed = operations.Linear(in_dim, out_dim_video, bias=True, dtype=dtype, device=device)
def forward(self, x):
source_dim = x.shape[-1]
x = x.movedim(1, -1)
x = (x * torch.rsqrt(torch.mean(x**2, dim=2, keepdim=True) + 1e-6)).flatten(start_dim=2)
video = self.video_aggregate_embed(x * math.sqrt(self.video_aggregate_embed.out_features / source_dim))
audio = self.audio_aggregate_embed(x * math.sqrt(self.audio_aggregate_embed.out_features / source_dim))
return torch.cat((video, audio), dim=-1)
class LTXAVTEModel(torch.nn.Module):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, model_options={}):
def __init__(self, dtype_llama=None, device="cpu", dtype=None, text_projection_type="single_linear", model_options={}):
super().__init__()
self.dtypes = set()
self.dtypes.add(dtype)
self.compat_mode = False
self.text_projection_type = text_projection_type
self.gemma3_12b = Gemma3_12BModel(device=device, dtype=dtype_llama, model_options=model_options, layer="all", layer_idx=None)
self.dtypes.add(dtype_llama)
operations = self.gemma3_12b.operations # TODO
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
if self.text_projection_type == "single_linear":
self.text_embedding_projection = operations.Linear(3840 * 49, 3840, bias=False, dtype=dtype, device=device)
elif self.text_projection_type == "dual_linear":
self.text_embedding_projection = DualLinearProjection(3840 * 49, 4096, 2048, dtype=dtype, device=device, operations=operations)
def enable_compat_mode(self): # TODO: remove
from comfy.ldm.lightricks.embeddings_connector import Embeddings1DConnector
@ -148,18 +169,25 @@ class LTXAVTEModel(torch.nn.Module):
out_device = out.device
if comfy.model_management.should_use_bf16(self.execution_device):
out = out.to(device=self.execution_device, dtype=torch.bfloat16)
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
out = out.float()
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
if self.text_projection_type == "single_linear":
out = out.movedim(1, -1).to(self.execution_device)
out = 8.0 * (out - out.mean(dim=(1, 2), keepdim=True)) / (out.amax(dim=(1, 2), keepdim=True) - out.amin(dim=(1, 2), keepdim=True) + 1e-6)
out = out.reshape((out.shape[0], out.shape[1], -1))
out = self.text_embedding_projection(out)
return out.to(out_device), pooled
if self.compat_mode:
out_vid = self.video_embeddings_connector(out)[0]
out_audio = self.audio_embeddings_connector(out)[0]
out = torch.concat((out_vid, out_audio), dim=-1)
extra = {}
else:
extra = {"unprocessed_ltxav_embeds": True}
elif self.text_projection_type == "dual_linear":
out = self.text_embedding_projection(out)
extra = {"unprocessed_ltxav_embeds": True}
return out.to(device=out_device, dtype=torch.float), pooled, extra
def generate(self, tokens, do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed):
return self.gemma3_12b.generate(tokens["gemma3_12b"], do_sample, max_length, temperature, top_k, top_p, min_p, repetition_penalty, seed)
@ -168,7 +196,7 @@ class LTXAVTEModel(torch.nn.Module):
if "model.layers.47.self_attn.q_norm.weight" in sd:
return self.gemma3_12b.load_sd(sd)
else:
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight"}, filter_keys=True)
sdo = comfy.utils.state_dict_prefix_replace(sd, {"text_embedding_projection.aggregate_embed.weight": "text_embedding_projection.weight", "text_embedding_projection.": "text_embedding_projection."}, filter_keys=True)
if len(sdo) == 0:
sdo = sd
@ -206,7 +234,7 @@ class LTXAVTEModel(torch.nn.Module):
num_tokens = max(num_tokens, 642)
return num_tokens * constant * 1024 * 1024
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
def ltxav_te(dtype_llama=None, llama_quantization_metadata=None, text_projection_type="single_linear"):
class LTXAVTEModel_(LTXAVTEModel):
def __init__(self, device="cpu", dtype=None, model_options={}):
if llama_quantization_metadata is not None:
@ -214,9 +242,19 @@ def ltxav_te(dtype_llama=None, llama_quantization_metadata=None):
model_options["llama_quantization_metadata"] = llama_quantization_metadata
if dtype_llama is not None:
dtype = dtype_llama
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, model_options=model_options)
super().__init__(dtype_llama=dtype_llama, device=device, dtype=dtype, text_projection_type=text_projection_type, model_options=model_options)
return LTXAVTEModel_
def sd_detect(state_dict_list, prefix=""):
for sd in state_dict_list:
if "{}text_embedding_projection.audio_aggregate_embed.bias".format(prefix) in sd:
return {"text_projection_type": "dual_linear"}
if "{}text_embedding_projection.weight".format(prefix) in sd or "{}text_embedding_projection.aggregate_embed.weight".format(prefix) in sd:
return {"text_projection_type": "single_linear"}
return {}
def gemma3_te(dtype_llama=None, llama_quantization_metadata=None):
class Gemma3_12BModel_(Gemma3_12BModel):
def __init__(self, device="cpu", dtype=None, model_options={}):

View File

@ -869,20 +869,31 @@ def safetensors_header(safetensors_path, max_size=100*1024*1024):
ATTR_UNSET={}
def set_attr(obj, attr, value):
def resolve_attr(obj, attr):
attrs = attr.split(".")
for name in attrs[:-1]:
obj = getattr(obj, name)
prev = getattr(obj, attrs[-1], ATTR_UNSET)
return obj, attrs[-1]
def set_attr(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
if value is ATTR_UNSET:
delattr(obj, attrs[-1])
delattr(obj, name)
else:
setattr(obj, attrs[-1], value)
setattr(obj, name, value)
return prev
def set_attr_param(obj, attr, value):
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value):
obj, name = resolve_attr(obj, attr)
prev = getattr(obj, name, ATTR_UNSET)
persistent = name not in getattr(obj, "_non_persistent_buffers_set", set())
obj.register_buffer(name, value, persistent=persistent)
return prev
def copy_to_param(obj, attr, value):
# inplace update tensor instead of replacing it
attrs = attr.split(".")

View File

@ -15,6 +15,7 @@ SERVER_FEATURE_FLAGS: dict[str, Any] = {
"max_upload_size": args.max_upload_size * 1024 * 1024, # Convert MB to bytes
"extension": {"manager": {"supports_v4": True}},
"node_replacements": True,
"assets": args.enable_assets,
}

View File

@ -401,6 +401,7 @@ class VideoFromComponents(VideoInput):
codec: VideoCodec = VideoCodec.AUTO,
metadata: Optional[dict] = None,
):
"""Save the video to a file path or BytesIO buffer."""
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
raise ValueError("Only MP4 format is supported for now")
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
@ -408,6 +409,10 @@ class VideoFromComponents(VideoInput):
extra_kwargs = {}
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
extra_kwargs["format"] = format.value
elif isinstance(path, io.BytesIO):
# BytesIO has no file extension, so av.open can't infer the format.
# Default to mp4 since that's the only supported format anyway.
extra_kwargs["format"] = "mp4"
with av.open(path, mode='w', options={'movflags': 'use_metadata_tags'}, **extra_kwargs) as output:
# Add metadata before writing any streams
if metadata is not None:

View File

@ -27,7 +27,7 @@ if TYPE_CHECKING:
from comfy_api.internal import (_ComfyNodeInternal, _NodeOutputInternal, classproperty, copy_class, first_real_override, is_class,
prune_dict, shallow_clone_class)
from comfy_execution.graph_utils import ExecutionBlocker
from ._util import MESH, VOXEL, SVG as _SVG, File3D
from ._util import MESH, VOXEL, SVG as _SVG, File3D, PLY as _PLY, NPZ as _NPZ
class FolderType(str, Enum):
@ -678,6 +678,16 @@ class Mesh(ComfyTypeIO):
Type = MESH
@comfytype(io_type="PLY")
class Ply(ComfyTypeIO):
Type = _PLY
@comfytype(io_type="NPZ")
class Npz(ComfyTypeIO):
Type = _NPZ
@comfytype(io_type="FILE_3D")
class File3DAny(ComfyTypeIO):
"""General 3D file type - accepts any supported 3D format."""
@ -1240,6 +1250,19 @@ class BoundingBox(ComfyTypeIO):
return d
@comfytype(io_type="CURVE")
class Curve(ComfyTypeIO):
CurvePoint = tuple[float, float]
Type = list[CurvePoint]
class Input(WidgetInput):
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None,
socketless: bool=True, default: list[tuple[float, float]]=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, None, default, socketless, None, None, None, None, advanced)
if default is None:
self.default = [(0.0, 0.0), (1.0, 1.0)]
DYNAMIC_INPUT_LOOKUP: dict[str, Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]] = {}
def register_dynamic_input_func(io_type: str, func: Callable[[dict[str, Any], dict[str, Any], tuple[str, dict[str, Any]], str, list[str] | None], None]):
DYNAMIC_INPUT_LOOKUP[io_type] = func
@ -2184,6 +2207,8 @@ __all__ = [
"LossMap",
"Voxel",
"Mesh",
"Ply",
"Npz",
"File3DAny",
"File3DGLB",
"File3DGLTF",
@ -2226,5 +2251,6 @@ __all__ = [
"PriceBadgeDepends",
"PriceBadge",
"BoundingBox",
"Curve",
"NodeReplace",
]

View File

@ -1,6 +1,8 @@
from .video_types import VideoContainer, VideoCodec, VideoComponents
from .geometry_types import VOXEL, MESH, File3D
from .image_types import SVG
from .ply_types import PLY
from .npz_types import NPZ
__all__ = [
# Utility Types
@ -11,4 +13,6 @@ __all__ = [
"MESH",
"File3D",
"SVG",
"PLY",
"NPZ",
]

View File

@ -0,0 +1,27 @@
from __future__ import annotations
import os
class NPZ:
"""Ordered collection of NPZ file payloads.
Each entry in ``frames`` is a complete compressed ``.npz`` file stored
as raw bytes (produced by ``numpy.savez_compressed`` into a BytesIO).
``save_to`` writes numbered files into a directory.
"""
def __init__(self, frames: list[bytes]) -> None:
self.frames = frames
@property
def num_frames(self) -> int:
return len(self.frames)
def save_to(self, directory: str, prefix: str = "frame") -> str:
os.makedirs(directory, exist_ok=True)
for i, frame_bytes in enumerate(self.frames):
path = os.path.join(directory, f"{prefix}_{i:06d}.npz")
with open(path, "wb") as f:
f.write(frame_bytes)
return directory

View File

@ -0,0 +1,97 @@
from __future__ import annotations
import numpy as np
class PLY:
"""Point cloud payload for PLY file output.
Supports two schemas:
- Pointcloud: xyz positions with optional colors, confidence, view_id (ASCII format)
- Gaussian: raw binary PLY data built by producer nodes using plyfile (binary format)
When ``raw_data`` is provided, the object acts as an opaque binary PLY
carrier and ``save_to`` writes the bytes directly.
"""
def __init__(
self,
points: np.ndarray | None = None,
colors: np.ndarray | None = None,
confidence: np.ndarray | None = None,
view_id: np.ndarray | None = None,
raw_data: bytes | None = None,
) -> None:
self.raw_data = raw_data
if raw_data is not None:
self.points = None
self.colors = None
self.confidence = None
self.view_id = None
return
if points is None:
raise ValueError("Either points or raw_data must be provided")
if points.ndim != 2 or points.shape[1] != 3:
raise ValueError(f"points must be (N, 3), got {points.shape}")
self.points = np.ascontiguousarray(points, dtype=np.float32)
self.colors = np.ascontiguousarray(colors, dtype=np.float32) if colors is not None else None
self.confidence = np.ascontiguousarray(confidence, dtype=np.float32) if confidence is not None else None
self.view_id = np.ascontiguousarray(view_id, dtype=np.int32) if view_id is not None else None
@property
def is_gaussian(self) -> bool:
return self.raw_data is not None
@property
def num_points(self) -> int:
if self.points is not None:
return self.points.shape[0]
return 0
@staticmethod
def _to_numpy(arr, dtype):
if arr is None:
return None
if hasattr(arr, "numpy"):
arr = arr.cpu().numpy() if hasattr(arr, "cpu") else arr.numpy()
return np.ascontiguousarray(arr, dtype=dtype)
def save_to(self, path: str) -> str:
if self.raw_data is not None:
with open(path, "wb") as f:
f.write(self.raw_data)
return path
self.points = self._to_numpy(self.points, np.float32)
self.colors = self._to_numpy(self.colors, np.float32)
self.confidence = self._to_numpy(self.confidence, np.float32)
self.view_id = self._to_numpy(self.view_id, np.int32)
N = self.num_points
header_lines = [
"ply",
"format ascii 1.0",
f"element vertex {N}",
"property float x",
"property float y",
"property float z",
]
if self.colors is not None:
header_lines += ["property uchar red", "property uchar green", "property uchar blue"]
if self.confidence is not None:
header_lines.append("property float confidence")
if self.view_id is not None:
header_lines.append("property int view_id")
header_lines.append("end_header")
with open(path, "w") as f:
f.write("\n".join(header_lines) + "\n")
for i in range(N):
parts = [f"{self.points[i, 0]} {self.points[i, 1]} {self.points[i, 2]}"]
if self.colors is not None:
r, g, b = (self.colors[i] * 255).clip(0, 255).astype(np.uint8)
parts.append(f"{r} {g} {b}")
if self.confidence is not None:
parts.append(f"{self.confidence[i]}")
if self.view_id is not None:
parts.append(f"{int(self.view_id[i])}")
f.write(" ".join(parts) + "\n")
return path

View File

@ -7,7 +7,8 @@ class ImageGenerationRequest(BaseModel):
aspect_ratio: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
resolution: str = Field(...)
class InputUrlObject(BaseModel):
@ -16,12 +17,13 @@ class InputUrlObject(BaseModel):
class ImageEditRequest(BaseModel):
model: str = Field(...)
image: InputUrlObject = Field(...)
images: list[InputUrlObject] = Field(...)
prompt: str = Field(...)
resolution: str = Field(...)
n: int = Field(...)
seed: int = Field(...)
response_for: str = Field("url")
response_format: str = Field("url")
aspect_ratio: str | None = Field(...)
class VideoGenerationRequest(BaseModel):
@ -47,8 +49,13 @@ class ImageResponseObject(BaseModel):
revised_prompt: str | None = Field(None)
class UsageObject(BaseModel):
cost_in_usd_ticks: int | None = Field(None)
class ImageGenerationResponse(BaseModel):
data: list[ImageResponseObject] = Field(...)
usage: UsageObject | None = Field(None)
class VideoGenerationResponse(BaseModel):
@ -65,3 +72,4 @@ class VideoStatusResponse(BaseModel):
status: str | None = Field(None)
video: VideoResponseObject | None = Field(None)
model: str | None = Field(None)
usage: UsageObject | None = Field(None)

View File

@ -66,13 +66,17 @@ class To3DProTaskQueryRequest(BaseModel):
JobId: str = Field(...)
class To3DUVFileInput(BaseModel):
class TaskFile3DInput(BaseModel):
Type: str = Field(..., description="File type: GLB, OBJ, or FBX")
Url: str = Field(...)
class To3DUVTaskRequest(BaseModel):
File: To3DUVFileInput = Field(...)
File: TaskFile3DInput = Field(...)
class To3DPartTaskRequest(BaseModel):
File: TaskFile3DInput = Field(...)
class TextureEditImageInfo(BaseModel):
@ -80,7 +84,13 @@ class TextureEditImageInfo(BaseModel):
class TextureEditTaskRequest(BaseModel):
File3D: To3DUVFileInput = Field(...)
File3D: TaskFile3DInput = Field(...)
Image: TextureEditImageInfo | None = Field(None)
Prompt: str | None = Field(None)
EnablePBR: bool | None = Field(None)
class SmartTopologyRequest(BaseModel):
File3D: TaskFile3DInput = Field(...)
PolygonType: str | None = Field(...)
FaceLevel: str | None = Field(...)

View File

@ -148,3 +148,4 @@ class MotionControlRequest(BaseModel):
keep_original_sound: str = Field(...)
character_orientation: str = Field(...)
mode: str = Field(..., description="'pro' or 'std'")
model_name: str = Field(...)

View File

@ -0,0 +1,68 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
process: str = Field(..., description="The postprocessing operation: upscale or remove_background.")
upscale_factor: int | None = Field(
None,
description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
ge=2,
le=4,
)
class ReveImageCreateRequest(BaseModel):
prompt: str = Field(...)
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageEditRequest(BaseModel):
edit_instruction: str = Field(...)
reference_image: str = Field(..., description="A base64 encoded image to use as reference for the edit.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageRemixRequest(BaseModel):
prompt: str = Field(...)
reference_images: list[str] = Field(..., description="A list of 1-6 base64 encoded reference images.")
aspect_ratio: str | None = Field(...)
version: str = Field(...)
test_time_scaling: int | None = Field(
...,
description="If included, the model will spend more effort making better images. Values between 1 and 15.",
ge=1,
le=15,
)
postprocessing: list[RevePostprocessingOperation] | None = Field(
None, description="Optional postprocessing operations to apply after generation."
)
class ReveImageResponse(BaseModel):
image: str | None = Field(None, description="The base64 encoded image data.")
request_id: str | None = Field(None, description="A unique id for the request.")
credits_used: float | None = Field(None, description="The number of credits used for this request.")
version: str | None = Field(None, description="The specific model version used.")
content_violation: bool | None = Field(
None, description="Indicates whether the generated image violates the content policy."
)

View File

@ -72,18 +72,6 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
)
class GeminiModel(str, Enum):
"""
Gemini Model Names allowed by comfy-api
"""
gemini_2_5_pro_preview_05_06 = "gemini-2.5-pro-preview-05-06"
gemini_2_5_flash_preview_04_17 = "gemini-2.5-flash-preview-04-17"
gemini_2_5_pro = "gemini-2.5-pro"
gemini_2_5_flash = "gemini-2.5-flash"
gemini_3_0_pro = "gemini-3-pro-preview"
class GeminiImageModel(str, Enum):
"""
Gemini Image Model Names allowed by comfy-api
@ -237,10 +225,14 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
input_tokens_price = 0.30
output_text_tokens_price = 2.50
output_image_tokens_price = 30.0
elif response.modelVersion == "gemini-3-pro-preview":
elif response.modelVersion in ("gemini-3-pro-preview", "gemini-3.1-pro-preview"):
input_tokens_price = 2
output_text_tokens_price = 12.0
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3.1-flash-lite-preview":
input_tokens_price = 0.25
output_text_tokens_price = 1.50
output_image_tokens_price = 0.0
elif response.modelVersion == "gemini-3-pro-image-preview":
input_tokens_price = 2
output_text_tokens_price = 12.0
@ -292,8 +284,16 @@ class GeminiNode(IO.ComfyNode):
),
IO.Combo.Input(
"model",
options=GeminiModel,
default=GeminiModel.gemini_2_5_pro,
options=[
"gemini-2.5-pro-preview-05-06",
"gemini-2.5-flash-preview-04-17",
"gemini-2.5-pro",
"gemini-2.5-flash",
"gemini-3-pro-preview",
"gemini-3-1-pro",
"gemini-3-1-flash-lite",
],
default="gemini-3-1-pro",
tooltip="The Gemini model to use for generating responses.",
),
IO.Int.Input(
@ -363,11 +363,16 @@ class GeminiNode(IO.ComfyNode):
"usd": [0.00125, 0.01],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gemini-3-pro-preview") ? {
: ($contains($m, "gemini-3-pro-preview") or $contains($m, "gemini-3-1-pro")) ? {
"type": "list_usd",
"usd": [0.002, 0.012],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: $contains($m, "gemini-3-1-flash-lite") ? {
"type": "list_usd",
"usd": [0.00025, 0.0015],
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
}
: {"type":"text", "text":"Token-based"}
)
""",
@ -436,12 +441,14 @@ class GeminiNode(IO.ComfyNode):
files: list[GeminiPart] | None = None,
system_prompt: str = "",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=False)
if model == "gemini-3-pro-preview":
model = "gemini-3.1-pro-preview" # model "gemini-3-pro-preview" will be soon deprecated by Google
elif model == "gemini-3-1-pro":
model = "gemini-3.1-pro-preview"
elif model == "gemini-3-1-flash-lite":
model = "gemini-3.1-flash-lite-preview"
# Create parts list with text prompt as the first part
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
# Add other modal parts
if images is not None:
parts.extend(await create_image_parts(cls, images))
if audio is not None:

View File

@ -27,6 +27,12 @@ from comfy_api_nodes.util import (
)
def _extract_grok_price(response) -> float | None:
if response.usage and response.usage.cost_in_usd_ticks is not None:
return response.usage.cost_in_usd_ticks / 10_000_000_000
return None
class GrokImageNode(IO.ComfyNode):
@classmethod
@ -37,7 +43,10 @@ class GrokImageNode(IO.ComfyNode):
category="api node/image/Grok",
description="Generate images using Grok based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.String.Input(
"prompt",
multiline=True,
@ -81,6 +90,7 @@ class GrokImageNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input("resolution", options=["1K", "2K"], optional=True),
],
outputs=[
IO.Image.Output(),
@ -92,8 +102,13 @@ class GrokImageNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": $rate * widgets.number_of_images}
)
""",
),
)
@ -105,6 +120,7 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio: str,
number_of_images: int,
seed: int,
resolution: str = "1K",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
response = await sync_op(
@ -116,8 +132,10 @@ class GrokImageNode(IO.ComfyNode):
aspect_ratio=aspect_ratio,
n=number_of_images,
seed=seed,
resolution=resolution.lower(),
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@ -138,14 +156,17 @@ class GrokImageEditNode(IO.ComfyNode):
category="api node/image/Grok",
description="Modify an existing image based on a text prompt",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-image-beta"]),
IO.Image.Input("image"),
IO.Combo.Input(
"model",
options=["grok-imagine-image-pro", "grok-imagine-image", "grok-imagine-image-beta"],
),
IO.Image.Input("image", display_name="images"),
IO.String.Input(
"prompt",
multiline=True,
tooltip="The text prompt used to generate the image",
),
IO.Combo.Input("resolution", options=["1K"]),
IO.Combo.Input("resolution", options=["1K", "2K"]),
IO.Int.Input(
"number_of_images",
default=1,
@ -166,6 +187,27 @@ class GrokImageEditNode(IO.ComfyNode):
tooltip="Seed to determine if node should re-run; "
"actual results are nondeterministic regardless of seed.",
),
IO.Combo.Input(
"aspect_ratio",
options=[
"auto",
"1:1",
"2:3",
"3:2",
"3:4",
"4:3",
"9:16",
"16:9",
"9:19.5",
"19.5:9",
"9:20",
"20:9",
"1:2",
"2:1",
],
optional=True,
tooltip="Only allowed when multiple images are connected to the image input.",
),
],
outputs=[
IO.Image.Output(),
@ -177,8 +219,13 @@ class GrokImageEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["number_of_images"]),
expr="""{"type":"usd","usd":0.002 + 0.033 * widgets.number_of_images}""",
depends_on=IO.PriceBadgeDepends(widgets=["model", "number_of_images"]),
expr="""
(
$rate := $contains(widgets.model, "pro") ? 0.07 : 0.02;
{"type":"usd","usd": 0.002 + $rate * widgets.number_of_images}
)
""",
),
)
@ -191,22 +238,32 @@ class GrokImageEditNode(IO.ComfyNode):
resolution: str,
number_of_images: int,
seed: int,
aspect_ratio: str = "auto",
) -> IO.NodeOutput:
validate_string(prompt, strip_whitespace=True, min_length=1)
if get_number_of_images(image) != 1:
raise ValueError("Only one input image is supported.")
if model == "grok-imagine-image-pro":
if get_number_of_images(image) > 1:
raise ValueError("The pro model supports only 1 input image.")
elif get_number_of_images(image) > 3:
raise ValueError("A maximum of 3 input images is supported.")
if aspect_ratio != "auto" and get_number_of_images(image) == 1:
raise ValueError(
"Custom aspect ratio is only allowed when multiple images are connected to the image input."
)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/xai/v1/images/edits", method="POST"),
data=ImageEditRequest(
model=model,
image=InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(image)}"),
images=[InputUrlObject(url=f"data:image/png;base64,{tensor_to_base64_string(i)}") for i in image],
prompt=prompt,
resolution=resolution.lower(),
n=number_of_images,
seed=seed,
aspect_ratio=None if aspect_ratio == "auto" else aspect_ratio,
),
response_model=ImageGenerationResponse,
price_extractor=_extract_grok_price,
)
if len(response.data) == 1:
return IO.NodeOutput(await download_url_to_image_tensor(response.data[0].url))
@ -227,7 +284,7 @@ class GrokVideoNode(IO.ComfyNode):
category="api node/video/Grok",
description="Generate video from a prompt or an image",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@ -275,10 +332,11 @@ class GrokVideoNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(widgets=["duration"], inputs=["image"]),
depends_on=IO.PriceBadgeDepends(widgets=["duration", "resolution"], inputs=["image"]),
expr="""
(
$base := 0.181 * widgets.duration;
$rate := widgets.resolution = "720p" ? 0.07 : 0.05;
$base := $rate * widgets.duration;
{"type":"usd","usd": inputs.image.connected ? $base + 0.002 : $base}
)
""",
@ -321,6 +379,7 @@ class GrokVideoNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))
@ -335,7 +394,7 @@ class GrokVideoEditNode(IO.ComfyNode):
category="api node/video/Grok",
description="Edit an existing video based on a text prompt.",
inputs=[
IO.Combo.Input("model", options=["grok-imagine-video-beta"]),
IO.Combo.Input("model", options=["grok-imagine-video", "grok-imagine-video-beta"]),
IO.String.Input(
"prompt",
multiline=True,
@ -364,7 +423,7 @@ class GrokVideoEditNode(IO.ComfyNode):
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd": 0.191, "format": {"suffix": "/sec", "approximate": true}}""",
expr="""{"type":"usd","usd": 0.06, "format": {"suffix": "/sec", "approximate": true}}""",
),
)
@ -398,6 +457,7 @@ class GrokVideoEditNode(IO.ComfyNode):
ApiEndpoint(path=f"/proxy/xai/v1/videos/{initial_response.request_id}"),
status_extractor=lambda r: r.status if r.status is not None else "complete",
response_model=VideoStatusResponse,
price_extractor=_extract_grok_price,
)
return IO.NodeOutput(await download_url_to_video_output(response.video.url))

View File

@ -5,18 +5,19 @@ from comfy_api_nodes.apis.hunyuan3d import (
Hunyuan3DViewImage,
InputGenerateType,
ResultFile3D,
SmartTopologyRequest,
TaskFile3DInput,
TextureEditTaskRequest,
To3DPartTaskRequest,
To3DProTaskCreateResponse,
To3DProTaskQueryRequest,
To3DProTaskRequest,
To3DProTaskResultResponse,
To3DUVFileInput,
To3DUVTaskRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side,
poll_op,
sync_op,
@ -344,7 +345,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
IO.File3DFBX.Output(display_name="FBX"),
IO.Image.Output(),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
@ -375,7 +375,7 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-uv", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(
File=TaskFile3DInput(
Type=file_format.upper(),
Url=await upload_3d_model_to_comfyapi(cls, model_3d, file_format),
)
@ -394,7 +394,6 @@ class TencentModelTo3DUVNode(IO.ComfyNode):
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "image").Url),
)
@ -463,7 +462,7 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-texture-edit", method="POST"),
response_model=To3DProTaskCreateResponse,
data=TextureEditTaskRequest(
File3D=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
Prompt=prompt,
EnablePBR=True,
),
@ -538,8 +537,8 @@ class Tencent3DPartNode(IO.ComfyNode):
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-part", method="POST"),
response_model=To3DProTaskCreateResponse,
data=To3DUVTaskRequest(
File=To3DUVFileInput(Type=file_format.upper(), Url=model_url),
data=To3DPartTaskRequest(
File=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
),
is_rate_limited=_is_tencent_rate_limited,
)
@ -557,15 +556,107 @@ class Tencent3DPartNode(IO.ComfyNode):
)
class TencentSmartTopologyNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="TencentSmartTopologyNode",
display_name="Hunyuan3D: Smart Topology",
category="api node/3d/Tencent",
description="Perform smart retopology on a 3D model. "
"Supports GLB/OBJ formats; max 200MB; recommended for high-poly models.",
inputs=[
IO.MultiType.Input(
"model_3d",
types=[IO.File3DGLB, IO.File3DOBJ, IO.File3DAny],
tooltip="Input 3D model (GLB or OBJ)",
),
IO.Combo.Input(
"polygon_type",
options=["triangle", "quadrilateral"],
tooltip="Surface composition type.",
),
IO.Combo.Input(
"face_level",
options=["medium", "high", "low"],
tooltip="Polygon reduction level.",
),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
display_mode=IO.NumberDisplay.number,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[
IO.File3DOBJ.Output(display_name="OBJ"),
],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(expr='{"type":"usd","usd":1.0}'),
)
SUPPORTED_FORMATS = {"glb", "obj"}
@classmethod
async def execute(
cls,
model_3d: Types.File3D,
polygon_type: str,
face_level: str,
seed: int,
) -> IO.NodeOutput:
_ = seed
file_format = model_3d.format.lower()
if file_format not in cls.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported file format: '{file_format}'. " f"Supported: {', '.join(sorted(cls.SUPPORTED_FORMATS))}."
)
model_url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
response = await sync_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology", method="POST"),
response_model=To3DProTaskCreateResponse,
data=SmartTopologyRequest(
File3D=TaskFile3DInput(Type=file_format.upper(), Url=model_url),
PolygonType=polygon_type,
FaceLevel=face_level,
),
is_rate_limited=_is_tencent_rate_limited,
)
if response.Error:
raise ValueError(f"Task creation failed: [{response.Error.Code}] {response.Error.Message}")
result = await poll_op(
cls,
ApiEndpoint(path="/proxy/tencent/hunyuan/3d-smart-topology/query", method="POST"),
data=To3DProTaskQueryRequest(JobId=response.JobId),
response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status,
)
return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
)
class TencentHunyuan3DExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
TencentTextToModelNode,
TencentImageToModelNode,
# TencentModelTo3DUVNode,
TencentModelTo3DUVNode,
# Tencent3DTextureEditNode,
Tencent3DPartNode,
TencentSmartTopologyNode,
]

View File

@ -2747,6 +2747,7 @@ class MotionControl(IO.ComfyNode):
"but the character orientation matches the reference image (camera/other details via prompt).",
),
IO.Combo.Input("mode", options=["pro", "std"]),
IO.Combo.Input("model", options=["kling-v3", "kling-v2-6"], optional=True),
],
outputs=[
IO.Video.Output(),
@ -2777,6 +2778,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound: bool,
character_orientation: str,
mode: str,
model: str = "kling-v2-6",
) -> IO.NodeOutput:
validate_string(prompt, max_length=2500)
validate_image_dimensions(reference_image, min_width=340, min_height=340)
@ -2797,6 +2799,7 @@ class MotionControl(IO.ComfyNode):
keep_original_sound="yes" if keep_original_sound else "no",
character_orientation=character_orientation,
mode=mode,
model_name=model,
),
)
if response.code:

View File

@ -0,0 +1,395 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
ops = []
if upscale["upscale"] == "enabled":
ops.append(
RevePostprocessingOperation(
process="upscale",
upscale_factor=upscale["upscale_factor"],
)
)
if remove_background:
ops.append(RevePostprocessingOperation(process="remove_background"))
return ops or None
def _postprocessing_inputs():
return [
IO.DynamicCombo.Input(
"upscale",
options=[
IO.DynamicCombo.Option("disabled", []),
IO.DynamicCombo.Option(
"enabled",
[
IO.Int.Input(
"upscale_factor",
default=2,
min=2,
max=4,
step=1,
tooltip="Upscale factor (2x, 3x, or 4x).",
),
],
),
],
tooltip="Upscale the generated image. May add additional cost.",
),
IO.Boolean.Input(
"remove_background",
default=False,
tooltip="Remove the background from the generated image. May add additional cost.",
),
]
def _reve_price_extractor(headers: dict) -> float | None:
credits_used = headers.get("x-reve-credits-used")
if credits_used is not None:
return float(credits_used) / 524.48
return None
def _reve_response_header_validator(headers: dict) -> None:
error_code = headers.get("x-reve-error-code")
if error_code:
raise ValueError(f"Reve API error: {error_code}")
if headers.get("x-reve-content-violation", "").lower() == "true":
raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
return [
IO.DynamicCombo.Option(
version,
[
IO.Combo.Input(
"aspect_ratio",
options=aspect_ratios,
tooltip="Aspect ratio of the output image.",
),
IO.Int.Input(
"test_time_scaling",
default=1,
min=1,
max=5,
step=1,
tooltip="Higher values produce better images but cost more credits.",
advanced=True,
),
],
)
for version in versions
]
class ReveImageCreateNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageCreateNode",
display_name="Reve Image Create",
category="api node/image/Reve",
description="Generate images from text descriptions using Reve.",
inputs=[
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-create@20250915"],
aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for generation.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
),
)
@classmethod
async def execute(
cls,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/create",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageCreateRequest(
prompt=prompt,
aspect_ratio=model["aspect_ratio"],
version=model["model"],
test_time_scaling=model["test_time_scaling"],
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageEditNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
return [
ReveImageCreateNode,
ReveImageEditNode,
ReveImageRemixNode,
]
async def comfy_entrypoint() -> ReveExtension:
return ReveExtension()

View File

@ -67,6 +67,7 @@ class _RequestConfig:
progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass
@ -83,7 +84,7 @@ class _PollUIState:
_RETRY_STATUS = {408, 500, 502, 503, 504} # status 429 is handled separately
COMPLETED_STATUSES = ["succeeded", "succeed", "success", "completed", "finished", "done", "complete"]
FAILED_STATUSES = ["cancelled", "canceled", "canceling", "fail", "failed", "error"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing"]
QUEUED_STATUSES = ["created", "queued", "queueing", "submitted", "initializing", "wait"]
async def sync_op(
@ -202,11 +203,13 @@ async def sync_op_raw(
monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes:
"""
Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
"""
if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True)
@ -232,6 +235,7 @@ async def sync_op_raw(
price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
)
return await _request_base(cfg, expect_binary=as_binary)
@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
)
bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response(
@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method,
request_url=url,
response_status_code=resp.status,
response_headers=dict(resp.headers),
response_headers=resp_headers,
response_content=bytes_payload,
)
return bytes_payload

View File

@ -96,7 +96,7 @@ class VAEEncodeAudio(IO.ComfyNode):
def vae_decode_audio(vae, samples, tile=None, overlap=None):
if tile is not None:
audio = vae.decode_tiled(samples["samples"], tile_y=tile, overlap=overlap).movedim(-1, 1)
audio = vae.decode_tiled(samples["samples"], tile_x=tile, tile_y=tile, overlap=overlap).movedim(-1, 1)
else:
audio = vae.decode(samples["samples"]).movedim(-1, 1)

View File

@ -253,10 +253,12 @@ class LTXVAddGuide(io.ComfyNode):
return frame_idx, latent_idx
@classmethod
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1):
def add_keyframe_index(cls, cond, frame_idx, guiding_latent, scale_factors, latent_downscale_factor=1, causal_fix=None):
keyframe_idxs, _ = get_keyframe_idxs(cond)
_, latent_coords = cls.PATCHIFIER.patchify(guiding_latent)
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=frame_idx == 0) # we need the causal fix only if we're placing the new latents at index 0
if causal_fix is None:
causal_fix = frame_idx == 0 or guiding_latent.shape[2] == 1
pixel_coords = latent_to_pixel_coords(latent_coords, scale_factors, causal_fix=causal_fix)
pixel_coords[:, 0] += frame_idx
# The following adjusts keyframe end positions for small grid IC-LoRA.
@ -278,12 +280,12 @@ class LTXVAddGuide(io.ComfyNode):
return node_helpers.conditioning_set_values(cond, {"keyframe_idxs": keyframe_idxs})
@classmethod
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1):
def append_keyframe(cls, positive, negative, frame_idx, latent_image, noise_mask, guiding_latent, strength, scale_factors, guide_mask=None, in_channels=128, latent_downscale_factor=1, causal_fix=None):
if latent_image.shape[1] != in_channels or guiding_latent.shape[1] != in_channels:
raise ValueError("Adding guide to a combined AV latent is not supported.")
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor)
positive = cls.add_keyframe_index(positive, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
negative = cls.add_keyframe_index(negative, frame_idx, guiding_latent, scale_factors, latent_downscale_factor, causal_fix=causal_fix)
if guide_mask is not None:
target_h = max(noise_mask.shape[3], guide_mask.shape[3])

119
comfy_extras/nodes_math.py Normal file
View File

@ -0,0 +1,119 @@
"""Math expression node using simpleeval for safe evaluation.
Provides a ComfyMathExpression node that evaluates math expressions
against dynamically-grown numeric inputs.
"""
from __future__ import annotations
import math
import string
from simpleeval import simple_eval
from typing_extensions import override
from comfy_api.latest import ComfyExtension, io
MAX_EXPONENT = 4000
def _variadic_sum(*args):
"""Support both sum(values) and sum(a, b, c)."""
if len(args) == 1 and hasattr(args[0], "__iter__"):
return sum(args[0])
return sum(args)
def _safe_pow(base, exp):
"""Wrap pow() with an exponent cap to prevent DoS via huge exponents.
The ** operator is already guarded by simpleeval's safe_power, but
pow() as a callable bypasses that guard.
"""
if abs(exp) > MAX_EXPONENT:
raise ValueError(f"Exponent {exp} exceeds maximum allowed ({MAX_EXPONENT})")
return pow(base, exp)
MATH_FUNCTIONS = {
"sum": _variadic_sum,
"min": min,
"max": max,
"abs": abs,
"round": round,
"pow": _safe_pow,
"sqrt": math.sqrt,
"ceil": math.ceil,
"floor": math.floor,
"log": math.log,
"log2": math.log2,
"log10": math.log10,
"sin": math.sin,
"cos": math.cos,
"tan": math.tan,
"int": int,
"float": float,
}
class MathExpressionNode(io.ComfyNode):
"""Evaluates a math expression against dynamically-grown inputs."""
@classmethod
def define_schema(cls) -> io.Schema:
autogrow = io.Autogrow.TemplateNames(
input=io.MultiType.Input("value", [io.Float, io.Int]),
names=list(string.ascii_lowercase),
min=1,
)
return io.Schema(
node_id="ComfyMathExpression",
display_name="Math Expression",
category="math",
search_aliases=[
"expression", "formula", "calculate", "calculator",
"eval", "math",
],
inputs=[
io.String.Input("expression", default="a + b", multiline=True),
io.Autogrow.Input("values", template=autogrow),
],
outputs=[
io.Float.Output(display_name="FLOAT"),
io.Int.Output(display_name="INT"),
],
)
@classmethod
def execute(
cls, expression: str, values: io.Autogrow.Type
) -> io.NodeOutput:
if not expression.strip():
raise ValueError("Expression cannot be empty.")
context: dict = dict(values)
context["values"] = list(values.values())
result = simple_eval(expression, names=context, functions=MATH_FUNCTIONS)
# bool check must come first because bool is a subclass of int in Python
if isinstance(result, bool) or not isinstance(result, (int, float)):
raise ValueError(
f"Math Expression '{expression}' must evaluate to a numeric result, "
f"got {type(result).__name__}: {result!r}"
)
if not math.isfinite(result):
raise ValueError(
f"Math Expression '{expression}' produced a non-finite result: {result}"
)
return io.NodeOutput(float(result), int(result))
class MathExtension(ComfyExtension):
@override
async def get_node_list(self) -> list[type[io.ComfyNode]]:
return [MathExpressionNode]
async def comfy_entrypoint() -> MathExtension:
return MathExtension()

View File

@ -0,0 +1,40 @@
import os
import folder_paths
from comfy_api.latest import io
from comfy_api.latest._util.npz_types import NPZ
class SaveNPZ(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SaveNPZ",
display_name="Save NPZ",
category="3d",
is_output_node=True,
inputs=[
io.Npz.Input("npz"),
io.String.Input("filename_prefix", default="da3_streaming/ComfyUI"),
],
)
@classmethod
def execute(cls, npz: NPZ, filename_prefix: str) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory()
)
batch_dir = os.path.join(full_output_folder, f"{filename}_{counter:05}")
os.makedirs(batch_dir, exist_ok=True)
filenames = []
for i, frame_bytes in enumerate(npz.frames):
f = f"frame_{i:06d}.npz"
with open(os.path.join(batch_dir, f), "wb") as fh:
fh.write(frame_bytes)
filenames.append(f)
return io.NodeOutput(ui={"npz_files": [{"folder": os.path.join(subfolder, f"{filename}_{counter:05}"), "count": len(filenames), "type": "output"}]})
NODE_CLASS_MAPPINGS = {
"SaveNPZ": SaveNPZ,
}

View File

@ -0,0 +1,34 @@
import os
import folder_paths
from comfy_api.latest import io
from comfy_api.latest._util.ply_types import PLY
class SavePLY(io.ComfyNode):
@classmethod
def define_schema(cls):
return io.Schema(
node_id="SavePLY",
display_name="Save PLY",
category="3d",
is_output_node=True,
inputs=[
io.Ply.Input("ply"),
io.String.Input("filename_prefix", default="pointcloud/ComfyUI"),
],
)
@classmethod
def execute(cls, ply: PLY, filename_prefix: str) -> io.NodeOutput:
full_output_folder, filename, counter, subfolder, _ = folder_paths.get_save_image_path(
filename_prefix, folder_paths.get_output_directory()
)
f = f"{filename}_{counter:05}_.ply"
ply.save_to(os.path.join(full_output_folder, f))
return io.NodeOutput(ui={"pointclouds": [{"filename": f, "subfolder": subfolder, "type": "output"}]})
NODE_CLASS_MAPPINGS = {
"SavePLY": SavePLY,
}

View File

@ -86,7 +86,8 @@ class ImageUpscaleWithModel(io.ComfyNode):
pbar = comfy.utils.ProgressBar(steps)
s = comfy.utils.tiled_scale(in_img, lambda a: upscale_model(a), tile_x=tile, tile_y=tile, overlap=overlap, upscale_amount=upscale_model.scale, pbar=pbar)
oom = False
except model_management.OOM_EXCEPTION as e:
except Exception as e:
model_management.raise_non_oom(e)
tile //= 2
if tile < 128:
raise e

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is
# updated in pyproject.toml.
__version__ = "0.15.1"
__version__ = "0.16.4"

View File

@ -645,7 +645,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
logging.error(traceback.format_exc())
tips = ""
if isinstance(ex, comfy.model_management.OOM_EXCEPTION):
if comfy.model_management.is_oom(ex):
tips = "This error means you ran out of memory on your GPU.\n\nTIPS: If the workflow worked before you might have accidentally set the batch_size to a large number."
logging.info("Memory summary: {}".format(comfy.model_management.debug_memory_summary()))
logging.error("Got an OOM, unloading all loaded models.")
@ -993,12 +993,14 @@ async def validate_inputs(prompt_id, prompt, item, validated):
continue
else:
try:
# Unwraps values wrapped in __value__ key. This is used to pass
# list widget value to execution, as by default list value is
# reserved to represent the connection between nodes.
if isinstance(val, dict) and "__value__" in val:
val = val["__value__"]
inputs[x] = val
# Unwraps values wrapped in __value__ key or typed wrapper.
# This is used to pass list widget values to execution,
# as by default list value is reserved to represent the
# connection between nodes.
if isinstance(val, dict):
if "__value__" in val:
val = val["__value__"]
inputs[x] = val
if input_type == "INT":
val = int(val)

68
main.py
View File

@ -17,13 +17,22 @@ import comfy.options
comfy.options.enable_args_parsing()
import importlib.util
import shutil
import importlib.metadata
import folder_paths
import time
from comfy.cli_args import args, enables_dynamic_vram
from app.logger import setup_logger
from app.assets.scanner import seed_assets
import itertools
import utils.extra_config # noqa: F401
from utils.mime_types import init_mime_types
import faulthandler
import logging
import sys
from comfy_execution.progress import get_progress_state
from comfy_execution.utils import get_executing_context
from comfy_api import feature_flags
from app.database.db import init_db, dependencies_available
import comfy_aimdo.control
@ -65,6 +74,13 @@ if IS_PRIMARY_PROCESS:
if not IS_PYISOLATE_CHILD:
setup_logger(log_level=args.verbose, use_stdout=args.log_stdout)
faulthandler.enable(file=sys.stderr, all_threads=False)
import comfy_aimdo.control
if enables_dynamic_vram():
comfy_aimdo.control.init()
if os.name == "nt":
os.environ['MIMALLOC_PURGE_DELAY'] = '0'
@ -99,8 +115,15 @@ if __name__ == "__main__":
def handle_comfyui_manager_unavailable():
if not args.windows_standalone_build:
logging.warning(f"\n\nYou appear to be running comfyui-manager from source, this is not recommended. Please install comfyui-manager using the following command:\ncommand:\n\t{sys.executable} -m pip install --pre comfyui_manager\n")
manager_req_path = os.path.join(os.path.dirname(os.path.abspath(folder_paths.__file__)), "manager_requirements.txt")
uv_available = shutil.which("uv") is not None
pip_cmd = f"{sys.executable} -m pip install -r {manager_req_path}"
msg = f"\n\nTo use the `--enable-manager` feature, the `comfyui-manager` package must be installed first.\ncommand:\n\t{pip_cmd}"
if uv_available:
msg += f"\nor using uv:\n\tuv pip install -r {manager_req_path}"
msg += "\n"
logging.warning(msg)
args.enable_manager = False
@ -200,6 +223,7 @@ def execute_prestartup_script():
if not IS_PYISOLATE_CHILD:
apply_custom_paths()
init_mime_types()
if args.enable_manager and not IS_PYISOLATE_CHILD:
comfyui_manager.prestartup()
@ -210,7 +234,6 @@ if not IS_PYISOLATE_CHILD:
# Main code
import asyncio
import shutil
import threading
import gc
@ -218,6 +241,7 @@ if 'torch' in sys.modules:
logging.warning("WARNING: Potential Error in code: Torch already imported, torch should never be imported before this point.")
import comfy.utils
from app.assets.seeder import asset_seeder
if not IS_PYISOLATE_CHILD:
import execution
@ -298,6 +322,7 @@ def prompt_worker(q, server_instance):
for k in sensitive:
extra_data[k] = sensitive[k]
asset_seeder.pause()
e.execute(item[2], prompt_id, extra_data, item[4])
need_gc = True
@ -342,6 +367,7 @@ def prompt_worker(q, server_instance):
last_gc_collect = current_time
need_gc = False
hook_breaker_ac10a0.restore_functions()
asset_seeder.resume()
async def run(server_instance, address='', port=8188, verbose=True, call_on_start=None):
@ -392,12 +418,29 @@ def cleanup_temp():
def setup_database():
try:
from app.database.db import init_db, dependencies_available
if dependencies_available():
init_db()
if not args.disable_assets_autoscan:
seed_assets(["models"], enable_logging=True)
if args.enable_assets:
if asset_seeder.start(roots=("models", "input", "output"), prune_first=True, compute_hashes=True):
logging.info("Background asset scan initiated for models, input, output")
except Exception as e:
if "database is locked" in str(e):
logging.error(
"Database is locked. Another ComfyUI process is already using this database.\n"
"To resolve this, specify a separate database file for this instance:\n"
" --database-url sqlite:///path/to/another.db"
)
sys.exit(1)
if args.enable_assets:
logging.error(
f"Failed to initialize database: {e}\n"
"The --enable-assets flag requires a working database connection.\n"
"To resolve this, try one of the following:\n"
" 1. Install the latest requirements: pip install -r requirements.txt\n"
" 2. Specify an alternative database URL: --database-url sqlite:///path/to/your.db\n"
" 3. Use an in-memory database: --database-url sqlite:///:memory:"
)
sys.exit(1)
logging.error(f"Failed to initialize database. Please ensure you have installed the latest requirements. If the error persists, please report this as in future the database will be required: {e}")
@ -473,8 +516,12 @@ if __name__ == "__main__":
# Running directly, just start ComfyUI.
logging.info("Python version: {}".format(sys.version))
if not IS_PYISOLATE_CHILD:
import comfyui_version
logging.info("ComfyUI version: {}".format(comfyui_version.__version__))
for package in ("comfy-aimdo", "comfy-kitchen"):
try:
logging.info("{} version: {}".format(package, importlib.metadata.version(package)))
except:
pass
if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
@ -486,5 +533,6 @@ if __name__ == "__main__":
event_loop.run_until_complete(x)
except KeyboardInterrupt:
logging.info("\nStopped server")
cleanup_temp()
finally:
asset_seeder.shutdown()
cleanup_temp()

View File

@ -1 +1 @@
comfyui_manager==4.1b1
comfyui_manager==4.1b2

View File

@ -2455,6 +2455,8 @@ async def init_builtin_extra_nodes():
"nodes_wan.py",
"nodes_lotus.py",
"nodes_hunyuan3d.py",
"nodes_save_ply.py",
"nodes_save_npz.py",
"nodes_primitive.py",
"nodes_cfg.py",
"nodes_optimalsteps.py",
@ -2486,6 +2488,8 @@ async def init_builtin_extra_nodes():
"nodes_toolkit.py",
"nodes_replacements.py",
"nodes_nag.py",
"nodes_sdpose.py",
"nodes_math.py",
]
import_failed = []

View File

@ -1,6 +1,6 @@
[project]
name = "ComfyUI"
version = "0.15.1"
version = "0.16.4"
readme = "README.md"
license = { file = "LICENSE" }
requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19
comfyui-workflow-templates==0.9.5
comfyui-frontend-package==1.41.16
comfyui-workflow-templates==0.9.18
comfyui-embedded-docs==0.4.3
torch
torchsde
@ -20,10 +20,13 @@ tqdm
psutil
alembic
SQLAlchemy
filelock
av>=14.2.0
comfy-kitchen>=0.2.7
comfy-aimdo>=0.2.4
comfy-aimdo>=0.2.10
requests
simpleeval>=1.0.0
blake3
#non essential dependencies:
kornia>=0.7.1
@ -33,4 +36,4 @@ pydantic-settings~=2.0
PyOpenGL
glfw
pyisolate==0.9.1
pyisolate==0.9.2

View File

@ -32,8 +32,8 @@ import node_helpers
from comfyui_version import __version__
from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal
from app.assets.scanner import seed_assets
from app.assets.api.routes import register_assets_system
from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes
from app.user_manager import UserManager
from app.model_manager import ModelFileManager
@ -198,10 +198,6 @@ class PromptServer():
if loop is None:
loop = asyncio.get_event_loop()
mimetypes.init()
mimetypes.add_type('application/javascript; charset=utf-8', '.js')
mimetypes.add_type('image/webp', '.webp')
self.user_manager = UserManager()
self.model_file_manager = ModelFileManager()
self.custom_node_manager = CustomNodeManager()
@ -240,7 +236,11 @@ class PromptServer():
else args.front_end_root
)
logging.info(f"[Prompt Server] web root: {self.web_root}")
register_assets_system(self.app, self.user_manager)
if args.enable_assets:
register_assets_routes(self.app, self.user_manager)
else:
register_assets_routes(self.app)
asset_seeder.disable()
routes = web.RouteTableDef()
self.routes = routes
self.last_node_id = None
@ -698,10 +698,7 @@ class PromptServer():
@routes.get("/object_info")
async def get_object_info(request):
try:
seed_assets(["models"])
except Exception as e:
logging.error(f"Failed to seed assets: {e}")
asset_seeder.start(roots=("models", "input", "output"))
with folder_paths.cache_helper:
out = {}
for x in nodes.NODE_CLASS_MAPPINGS:

View File

@ -108,7 +108,7 @@ def comfy_url_and_proc(comfy_tmp_base_dir: Path, request: pytest.FixtureRequest)
"main.py",
f"--base-directory={str(comfy_tmp_base_dir)}",
f"--database-url={db_url}",
"--disable-assets-autoscan",
"--enable-assets",
"--listen",
"127.0.0.1",
"--port",
@ -212,7 +212,7 @@ def asset_factory(http: requests.Session, api_base: str):
for aid in created:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
@pytest.fixture
@ -258,14 +258,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
break
for aid in ids:
with contextlib.suppress(Exception):
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a fast sync/seed pass by calling the seed endpoint."""
session.post(base_url + "/api/assets/seed", json={"roots": ["models", "input", "output"]}, timeout=30)
time.sleep(0.2)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)

View File

@ -0,0 +1,28 @@
"""Helper functions for assets integration tests."""
import time
import requests
def trigger_sync_seed_assets(session: requests.Session, base_url: str) -> None:
"""Force a synchronous sync/seed pass by calling the seed endpoint with wait=true.
Retries on 409 (already running) until the previous scan finishes.
"""
deadline = time.monotonic() + 60
while True:
r = session.post(
base_url + "/api/assets/seed?wait=true",
json={"roots": ["models", "input", "output"]},
timeout=60,
)
if r.status_code != 409:
assert r.status_code == 200, f"seed endpoint returned {r.status_code}: {r.text}"
return
if time.monotonic() > deadline:
raise TimeoutError("seed endpoint stuck in 409 (already running)")
time.sleep(0.25)
def get_asset_filename(asset_hash: str, extension: str) -> str:
return asset_hash.removeprefix("blake3:") + extension

View File

@ -0,0 +1,20 @@
import pytest
from sqlalchemy import create_engine
from sqlalchemy.orm import Session
from app.assets.database.models import Base
@pytest.fixture
def session():
"""In-memory SQLite session for fast unit tests."""
engine = create_engine("sqlite:///:memory:")
Base.metadata.create_all(engine)
with Session(engine) as sess:
yield sess
@pytest.fixture(autouse=True)
def autoclean_unit_test_assets():
"""Override parent autouse fixture - query tests don't need server cleanup."""
yield

View File

@ -0,0 +1,144 @@
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.helpers import get_utc_now
from app.assets.database.models import Asset
from app.assets.database.queries import (
asset_exists_by_hash,
get_asset_by_hash,
upsert_asset,
bulk_insert_assets,
)
class TestAssetExistsByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,expected",
[
(None, "nonexistent", False), # No asset exists
("blake3:abc123", "blake3:abc123", True), # Asset exists with matching hash
(None, "", False), # Null hash in DB doesn't match empty string
],
ids=["nonexistent", "existing", "null_hash_no_match"],
)
def test_exists_by_hash(self, session: Session, setup_hash, query_hash, expected):
if setup_hash is not None or query_hash == "":
asset = Asset(hash=setup_hash, size_bytes=100)
session.add(asset)
session.commit()
assert asset_exists_by_hash(session, asset_hash=query_hash) is expected
class TestGetAssetByHash:
@pytest.mark.parametrize(
"setup_hash,query_hash,should_find",
[
(None, "nonexistent", False),
("blake3:def456", "blake3:def456", True),
],
ids=["nonexistent", "existing"],
)
def test_get_by_hash(self, session: Session, setup_hash, query_hash, should_find):
if setup_hash is not None:
asset = Asset(hash=setup_hash, size_bytes=200, mime_type="image/png")
session.add(asset)
session.commit()
result = get_asset_by_hash(session, asset_hash=query_hash)
if should_find:
assert result is not None
assert result.size_bytes == 200
assert result.mime_type == "image/png"
else:
assert result is None
class TestUpsertAsset:
@pytest.mark.parametrize(
"first_size,first_mime,second_size,second_mime,expect_created,expect_updated,final_size,final_mime",
[
# New asset creation
(None, None, 1024, "application/octet-stream", True, False, 1024, "application/octet-stream"),
# Existing asset, same values - no update
(500, "text/plain", 500, "text/plain", False, False, 500, "text/plain"),
# Existing asset with size 0, update with new values
(0, None, 2048, "image/png", False, True, 2048, "image/png"),
# Existing asset, second call with size 0 - no update
(1000, None, 0, None, False, False, 1000, None),
],
ids=["new_asset", "existing_no_change", "update_from_zero", "zero_size_no_update"],
)
def test_upsert_scenarios(
self,
session: Session,
first_size,
first_mime,
second_size,
second_mime,
expect_created,
expect_updated,
final_size,
final_mime,
):
asset_hash = f"blake3:test_{first_size}_{second_size}"
# First upsert (if first_size is not None, we're testing the second call)
if first_size is not None:
upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=first_size,
mime_type=first_mime,
)
session.commit()
# The upsert call we're testing
asset, created, updated = upsert_asset(
session,
asset_hash=asset_hash,
size_bytes=second_size,
mime_type=second_mime,
)
session.commit()
assert created is expect_created
assert updated is expect_updated
assert asset.size_bytes == final_size
assert asset.mime_type == final_mime
class TestBulkInsertAssets:
def test_inserts_multiple_assets(self, session: Session):
now = get_utc_now()
rows = [
{"id": str(uuid.uuid4()), "hash": "blake3:bulk1", "size_bytes": 100, "mime_type": "text/plain", "created_at": now},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk2", "size_bytes": 200, "mime_type": "image/png", "created_at": now},
{"id": str(uuid.uuid4()), "hash": "blake3:bulk3", "size_bytes": 300, "mime_type": None, "created_at": now},
]
bulk_insert_assets(session, rows)
session.commit()
assets = session.query(Asset).all()
assert len(assets) == 3
hashes = {a.hash for a in assets}
assert hashes == {"blake3:bulk1", "blake3:bulk2", "blake3:bulk3"}
def test_empty_list_is_noop(self, session: Session):
bulk_insert_assets(session, [])
session.commit()
assert session.query(Asset).count() == 0
def test_handles_large_batch(self, session: Session):
"""Test chunking logic with more rows than MAX_BIND_PARAMS allows."""
now = get_utc_now()
rows = [
{"id": str(uuid.uuid4()), "hash": f"blake3:large{i}", "size_bytes": i, "mime_type": None, "created_at": now}
for i in range(200)
]
bulk_insert_assets(session, rows)
session.commit()
assert session.query(Asset).count() == 200

View File

@ -0,0 +1,517 @@
import time
import uuid
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference, AssetReferenceMeta
from app.assets.database.queries import (
reference_exists_for_asset_id,
get_reference_by_id,
insert_reference,
get_or_create_reference,
update_reference_timestamps,
list_references_page,
fetch_reference_asset_and_tags,
fetch_reference_and_asset,
update_reference_access_time,
set_reference_metadata,
delete_reference_by_id,
set_reference_preview,
bulk_insert_references_ignore_conflicts,
get_reference_ids_by_ids,
ensure_tags_exist,
add_tags_to_reference,
)
from app.assets.helpers import get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size, mime_type="application/octet-stream")
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestReferenceExistsForAssetId:
def test_returns_false_when_no_reference(self, session: Session):
asset = _make_asset(session, "hash1")
assert reference_exists_for_asset_id(session, asset_id=asset.id) is False
def test_returns_true_when_reference_exists(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset)
assert reference_exists_for_asset_id(session, asset_id=asset.id) is True
class TestGetReferenceById:
def test_returns_none_for_nonexistent(self, session: Session):
assert get_reference_by_id(session, reference_id="nonexistent") is None
def test_returns_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="myfile.txt")
result = get_reference_by_id(session, reference_id=ref.id)
assert result is not None
assert result.name == "myfile.txt"
class TestListReferencesPage:
def test_empty_db(self, session: Session):
refs, tag_map, total = list_references_page(session)
assert refs == []
assert tag_map == {}
assert total == 0
def test_returns_references_with_tags(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="test.bin")
ensure_tags_exist(session, ["alpha", "beta"])
add_tags_to_reference(session, reference_id=ref.id, tags=["alpha", "beta"])
session.commit()
refs, tag_map, total = list_references_page(session)
assert len(refs) == 1
assert refs[0].id == ref.id
assert set(tag_map[ref.id]) == {"alpha", "beta"}
assert total == 1
def test_name_contains_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="model_v1.safetensors")
_make_reference(session, asset, name="config.json")
session.commit()
refs, _, total = list_references_page(session, name_contains="model")
assert total == 1
assert refs[0].name == "model_v1.safetensors"
def test_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="public", owner_id="")
_make_reference(session, asset, name="private", owner_id="user1")
session.commit()
# Empty owner sees only public
refs, _, total = list_references_page(session, owner_id="")
assert total == 1
assert refs[0].name == "public"
# Owner sees both
refs, _, total = list_references_page(session, owner_id="user1")
assert total == 2
def test_include_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="tagged")
_make_reference(session, asset, name="untagged")
ensure_tags_exist(session, ["wanted"])
add_tags_to_reference(session, reference_id=ref1.id, tags=["wanted"])
session.commit()
refs, _, total = list_references_page(session, include_tags=["wanted"])
assert total == 1
assert refs[0].name == "tagged"
def test_exclude_tags_filter(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="keep")
ref_exclude = _make_reference(session, asset, name="exclude")
ensure_tags_exist(session, ["bad"])
add_tags_to_reference(session, reference_id=ref_exclude.id, tags=["bad"])
session.commit()
refs, _, total = list_references_page(session, exclude_tags=["bad"])
assert total == 1
assert refs[0].name == "keep"
def test_sorting(self, session: Session):
asset = _make_asset(session, "hash1", size=100)
asset2 = _make_asset(session, "hash2", size=500)
_make_reference(session, asset, name="small")
_make_reference(session, asset2, name="large")
session.commit()
refs, _, _ = list_references_page(session, sort="size", order="desc")
assert refs[0].name == "large"
refs, _, _ = list_references_page(session, sort="name", order="asc")
assert refs[0].name == "large"
class TestFetchReferenceAssetAndTags:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_reference_asset_and_tags(session, "nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, name="test.bin")
ensure_tags_exist(session, ["tag1"])
add_tags_to_reference(session, reference_id=ref.id, tags=["tag1"])
session.commit()
result = fetch_reference_asset_and_tags(session, ref.id)
assert result is not None
ret_ref, ret_asset, ret_tags = result
assert ret_ref.id == ref.id
assert ret_asset.id == asset.id
assert ret_tags == ["tag1"]
class TestFetchReferenceAndAsset:
def test_returns_none_for_nonexistent(self, session: Session):
result = fetch_reference_and_asset(session, reference_id="nonexistent")
assert result is None
def test_returns_tuple(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
result = fetch_reference_and_asset(session, reference_id=ref.id)
assert result is not None
ret_ref, ret_asset = result
assert ret_ref.id == ref.id
assert ret_asset.id == asset.id
class TestUpdateReferenceAccessTime:
def test_updates_last_access_time(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
original_time = ref.last_access_time
session.commit()
import time
time.sleep(0.01)
update_reference_access_time(session, reference_id=ref.id)
session.commit()
session.refresh(ref)
assert ref.last_access_time > original_time
class TestDeleteReferenceById:
def test_deletes_existing(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
result = delete_reference_by_id(session, reference_id=ref.id, owner_id="")
assert result is True
assert get_reference_by_id(session, reference_id=ref.id) is None
def test_returns_false_for_nonexistent(self, session: Session):
result = delete_reference_by_id(session, reference_id="nonexistent", owner_id="")
assert result is False
def test_respects_owner_visibility(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, owner_id="user1")
session.commit()
result = delete_reference_by_id(session, reference_id=ref.id, owner_id="user2")
assert result is False
assert get_reference_by_id(session, reference_id=ref.id) is not None
class TestSetReferencePreview:
def test_sets_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
session.commit()
session.refresh(ref)
assert ref.preview_id == preview_asset.id
def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id
session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
session.commit()
session.refresh(ref)
assert ref.preview_id is None
def test_raises_for_nonexistent_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)
def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
with pytest.raises(ValueError, match="Preview Asset"):
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")
class TestInsertReference:
def test_creates_new_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref = insert_reference(
session, asset_id=asset.id, owner_id="user1", name="test.bin"
)
session.commit()
assert ref is not None
assert ref.name == "test.bin"
assert ref.owner_id == "user1"
def test_allows_duplicate_names(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = insert_reference(session, asset_id=asset.id, owner_id="user1", name="dup.bin")
session.commit()
# Duplicate names are now allowed
ref2 = insert_reference(
session, asset_id=asset.id, owner_id="user1", name="dup.bin"
)
session.commit()
assert ref1 is not None
assert ref2 is not None
assert ref1.id != ref2.id
class TestGetOrCreateReference:
def test_creates_new_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref, created = get_or_create_reference(
session, asset_id=asset.id, owner_id="user1", name="new.bin"
)
session.commit()
assert created is True
assert ref.name == "new.bin"
def test_always_creates_new_reference(self, session: Session):
asset = _make_asset(session, "hash1")
ref1, created1 = get_or_create_reference(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
# Duplicate names are allowed, so always creates new
ref2, created2 = get_or_create_reference(
session, asset_id=asset.id, owner_id="user1", name="existing.bin"
)
session.commit()
assert created1 is True
assert created2 is True
assert ref1.id != ref2.id
class TestUpdateReferenceTimestamps:
def test_updates_timestamps(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
original_updated_at = ref.updated_at
session.commit()
time.sleep(0.01)
update_reference_timestamps(session, ref)
session.commit()
session.refresh(ref)
assert ref.updated_at > original_updated_at
def test_updates_preview_id(self, session: Session):
asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset)
session.commit()
update_reference_timestamps(session, ref, preview_id=preview_asset.id)
session.commit()
session.refresh(ref)
assert ref.preview_id == preview_asset.id
class TestSetReferenceMetadata:
def test_sets_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"key": "value"}
)
session.commit()
session.refresh(ref)
assert ref.user_metadata == {"key": "value"}
# Check metadata table
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 1
assert meta[0].key == "key"
assert meta[0].val_str == "value"
def test_replaces_existing_metadata(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"old": "data"}
)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"new": "data"}
)
session.commit()
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 1
assert meta[0].key == "new"
def test_clears_metadata_with_empty_dict(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={"key": "value"}
)
session.commit()
set_reference_metadata(
session, reference_id=ref.id, user_metadata={}
)
session.commit()
session.refresh(ref)
assert ref.user_metadata == {}
meta = session.query(AssetReferenceMeta).filter_by(asset_reference_id=ref.id).all()
assert len(meta) == 0
def test_raises_for_nonexistent(self, session: Session):
with pytest.raises(ValueError, match="not found"):
set_reference_metadata(
session, reference_id="nonexistent", user_metadata={"key": "value"}
)
class TestBulkInsertReferencesIgnoreConflicts:
def test_inserts_multiple_references(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk1.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "bulk2.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
refs = session.query(AssetReference).all()
assert len(refs) == 2
def test_allows_duplicate_names(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, name="existing.bin", owner_id="")
session.commit()
now = get_utc_now()
rows = [
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "existing.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"id": str(uuid.uuid4()),
"owner_id": "",
"name": "new.bin",
"asset_id": asset.id,
"preview_id": None,
"user_metadata": {},
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
# Duplicate names allowed, so all 3 rows exist
refs = session.query(AssetReference).all()
assert len(refs) == 3
def test_empty_list_is_noop(self, session: Session):
bulk_insert_references_ignore_conflicts(session, [])
assert session.query(AssetReference).count() == 0
class TestGetReferenceIdsByIds:
def test_returns_existing_ids(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, name="a.bin")
ref2 = _make_reference(session, asset, name="b.bin")
session.commit()
found = get_reference_ids_by_ids(session, [ref1.id, ref2.id, "nonexistent"])
assert found == {ref1.id, ref2.id}
def test_empty_list_returns_empty(self, session: Session):
found = get_reference_ids_by_ids(session, [])
assert found == set()

View File

@ -0,0 +1,499 @@
"""Tests for cache_state (AssetReference file path) query functions."""
import pytest
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import (
list_references_by_asset_id,
upsert_reference,
get_unreferenced_unhashed_asset_ids,
delete_assets_by_ids,
get_references_for_prefixes,
bulk_update_needs_verify,
delete_references_by_ids,
delete_orphaned_seed_asset,
bulk_insert_references_ignore_conflicts,
get_references_by_paths_and_asset_ids,
mark_references_missing_outside_prefixes,
restore_references_by_paths,
)
from app.assets.helpers import select_best_live_path, get_utc_now
def _make_asset(session: Session, hash_val: str | None = None, size: int = 1024) -> Asset:
asset = Asset(hash=hash_val, size_bytes=size)
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
file_path: str,
name: str = "test",
mtime_ns: int | None = None,
needs_verify: bool = False,
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
asset_id=asset.id,
file_path=file_path,
name=name,
mtime_ns=mtime_ns,
needs_verify=needs_verify,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestListReferencesByAssetId:
def test_returns_empty_for_no_references(self, session: Session):
asset = _make_asset(session, "hash1")
refs = list_references_by_asset_id(session, asset_id=asset.id)
assert list(refs) == []
def test_returns_references_for_asset(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/path/a.bin", name="a")
_make_reference(session, asset, "/path/b.bin", name="b")
session.commit()
refs = list_references_by_asset_id(session, asset_id=asset.id)
paths = [r.file_path for r in refs]
assert set(paths) == {"/path/a.bin", "/path/b.bin"}
def test_does_not_return_other_assets_references(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_reference(session, asset1, "/path/asset1.bin", name="a1")
_make_reference(session, asset2, "/path/asset2.bin", name="a2")
session.commit()
refs = list_references_by_asset_id(session, asset_id=asset1.id)
paths = [r.file_path for r in refs]
assert paths == ["/path/asset1.bin"]
class TestSelectBestLivePath:
def test_returns_empty_for_empty_list(self):
result = select_best_live_path([])
assert result == ""
def test_returns_empty_when_no_files_exist(self, session: Session):
asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset, "/nonexistent/path.bin")
session.commit()
result = select_best_live_path([ref])
assert result == ""
def test_prefers_verified_path(self, session: Session, tmp_path):
"""needs_verify=False should be preferred."""
asset = _make_asset(session, "hash1")
verified_file = tmp_path / "verified.bin"
verified_file.write_bytes(b"data")
unverified_file = tmp_path / "unverified.bin"
unverified_file.write_bytes(b"data")
ref_verified = _make_reference(
session, asset, str(verified_file), name="verified", needs_verify=False
)
ref_unverified = _make_reference(
session, asset, str(unverified_file), name="unverified", needs_verify=True
)
session.commit()
refs = [ref_unverified, ref_verified]
result = select_best_live_path(refs)
assert result == str(verified_file)
def test_falls_back_to_existing_unverified(self, session: Session, tmp_path):
"""If all references need verification, return first existing path."""
asset = _make_asset(session, "hash1")
existing_file = tmp_path / "exists.bin"
existing_file.write_bytes(b"data")
ref = _make_reference(session, asset, str(existing_file), needs_verify=True)
session.commit()
result = select_best_live_path([ref])
assert result == str(existing_file)
class TestSelectBestLivePathWithMocking:
def test_handles_missing_file_path_attr(self):
"""Gracefully handle references with None file_path."""
class MockRef:
file_path = None
needs_verify = False
result = select_best_live_path([MockRef()])
assert result == ""
class TestUpsertReference:
@pytest.mark.parametrize(
"initial_mtime,second_mtime,expect_created,expect_updated,final_mtime",
[
# New reference creation
(None, 12345, True, False, 12345),
# Existing reference, same mtime - no update
(100, 100, False, False, 100),
# Existing reference, different mtime - update
(100, 200, False, True, 200),
],
ids=["new_reference", "existing_no_change", "existing_update_mtime"],
)
def test_upsert_scenarios(
self, session: Session, initial_mtime, second_mtime, expect_created, expect_updated, final_mtime
):
asset = _make_asset(session, "hash1")
file_path = f"/path_{initial_mtime}_{second_mtime}.bin"
name = f"file_{initial_mtime}_{second_mtime}"
# Create initial reference if needed
if initial_mtime is not None:
upsert_reference(session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=initial_mtime)
session.commit()
# The upsert call we're testing
created, updated = upsert_reference(
session, asset_id=asset.id, file_path=file_path, name=name, mtime_ns=second_mtime
)
session.commit()
assert created is expect_created
assert updated is expect_updated
ref = session.query(AssetReference).filter_by(file_path=file_path).one()
assert ref.mtime_ns == final_mtime
def test_upsert_restores_missing_reference(self, session: Session):
"""Upserting a reference that was marked missing should restore it."""
asset = _make_asset(session, "hash1")
file_path = "/restored/file.bin"
ref = _make_reference(session, asset, file_path, mtime_ns=100)
ref.is_missing = True
session.commit()
created, updated = upsert_reference(
session, asset_id=asset.id, file_path=file_path, name="restored", mtime_ns=100
)
session.commit()
assert created is False
assert updated is True
restored_ref = session.query(AssetReference).filter_by(file_path=file_path).one()
assert restored_ref.is_missing is False
class TestRestoreReferencesByPaths:
def test_restores_missing_references(self, session: Session):
asset = _make_asset(session, "hash1")
missing_path = "/missing/file.bin"
active_path = "/active/file.bin"
missing_ref = _make_reference(session, asset, missing_path, name="missing")
missing_ref.is_missing = True
_make_reference(session, asset, active_path, name="active")
session.commit()
restored = restore_references_by_paths(session, [missing_path])
session.commit()
assert restored == 1
ref = session.query(AssetReference).filter_by(file_path=missing_path).one()
assert ref.is_missing is False
def test_empty_list_restores_nothing(self, session: Session):
restored = restore_references_by_paths(session, [])
assert restored == 0
class TestMarkReferencesMissingOutsidePrefixes:
def test_marks_references_missing_outside_prefixes(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
valid_dir = tmp_path / "valid"
valid_dir.mkdir()
invalid_dir = tmp_path / "invalid"
invalid_dir.mkdir()
valid_path = str(valid_dir / "file.bin")
invalid_path = str(invalid_dir / "file.bin")
_make_reference(session, asset, valid_path, name="valid")
_make_reference(session, asset, invalid_path, name="invalid")
session.commit()
marked = mark_references_missing_outside_prefixes(session, [str(valid_dir)])
session.commit()
assert marked == 1
all_refs = session.query(AssetReference).all()
assert len(all_refs) == 2
valid_ref = next(r for r in all_refs if r.file_path == valid_path)
invalid_ref = next(r for r in all_refs if r.file_path == invalid_path)
assert valid_ref.is_missing is False
assert invalid_ref.is_missing is True
def test_empty_prefixes_marks_nothing(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/some/path.bin")
session.commit()
marked = mark_references_missing_outside_prefixes(session, [])
assert marked == 0
class TestGetUnreferencedUnhashedAssetIds:
def test_returns_unreferenced_unhashed_assets(self, session: Session):
# Unhashed asset (hash=None) with no references (no file_path)
no_refs = _make_asset(session, hash_val=None)
# Unhashed asset with active reference (not unreferenced)
with_active_ref = _make_asset(session, hash_val=None)
_make_reference(session, with_active_ref, "/has/ref.bin", name="has_ref")
# Unhashed asset with only missing reference (should be unreferenced)
with_missing_ref = _make_asset(session, hash_val=None)
missing_ref = _make_reference(session, with_missing_ref, "/missing/ref.bin", name="missing_ref")
missing_ref.is_missing = True
# Regular asset (hash not None) - should not be returned
_make_asset(session, hash_val="blake3:regular")
session.commit()
unreferenced = get_unreferenced_unhashed_asset_ids(session)
assert no_refs.id in unreferenced
assert with_missing_ref.id in unreferenced
assert with_active_ref.id not in unreferenced
class TestDeleteAssetsByIds:
def test_deletes_assets_and_references(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/test/path.bin", name="test")
session.commit()
deleted = delete_assets_by_ids(session, [asset.id])
session.commit()
assert deleted == 1
assert session.query(Asset).count() == 0
assert session.query(AssetReference).count() == 0
def test_empty_list_deletes_nothing(self, session: Session):
_make_asset(session, "hash1")
session.commit()
deleted = delete_assets_by_ids(session, [])
assert deleted == 0
assert session.query(Asset).count() == 1
class TestGetReferencesForPrefixes:
def test_returns_references_matching_prefix(self, session: Session, tmp_path):
asset = _make_asset(session, "hash1")
dir1 = tmp_path / "dir1"
dir1.mkdir()
dir2 = tmp_path / "dir2"
dir2.mkdir()
path1 = str(dir1 / "file.bin")
path2 = str(dir2 / "file.bin")
_make_reference(session, asset, path1, name="file1", mtime_ns=100)
_make_reference(session, asset, path2, name="file2", mtime_ns=200)
session.commit()
rows = get_references_for_prefixes(session, [str(dir1)])
assert len(rows) == 1
assert rows[0].file_path == path1
def test_empty_prefixes_returns_empty(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/some/path.bin")
session.commit()
rows = get_references_for_prefixes(session, [])
assert rows == []
class TestBulkSetNeedsVerify:
def test_sets_needs_verify_flag(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, "/path1.bin", needs_verify=False)
ref2 = _make_reference(session, asset, "/path2.bin", needs_verify=False)
session.commit()
updated = bulk_update_needs_verify(session, [ref1.id, ref2.id], True)
session.commit()
assert updated == 2
session.refresh(ref1)
session.refresh(ref2)
assert ref1.needs_verify is True
assert ref2.needs_verify is True
def test_empty_list_updates_nothing(self, session: Session):
updated = bulk_update_needs_verify(session, [], True)
assert updated == 0
class TestDeleteReferencesByIds:
def test_deletes_references_by_id(self, session: Session):
asset = _make_asset(session, "hash1")
ref1 = _make_reference(session, asset, "/path1.bin")
_make_reference(session, asset, "/path2.bin")
session.commit()
deleted = delete_references_by_ids(session, [ref1.id])
session.commit()
assert deleted == 1
assert session.query(AssetReference).count() == 1
def test_empty_list_deletes_nothing(self, session: Session):
deleted = delete_references_by_ids(session, [])
assert deleted == 0
class TestDeleteOrphanedSeedAsset:
@pytest.mark.parametrize(
"create_asset,expected_deleted,expected_count",
[
(True, True, 0), # Existing asset gets deleted
(False, False, 0), # Nonexistent returns False
],
ids=["deletes_existing", "nonexistent_returns_false"],
)
def test_delete_orphaned_seed_asset(
self, session: Session, create_asset, expected_deleted, expected_count
):
asset_id = "nonexistent-id"
if create_asset:
asset = _make_asset(session, hash_val=None)
asset_id = asset.id
_make_reference(session, asset, "/test/path.bin", name="test")
session.commit()
deleted = delete_orphaned_seed_asset(session, asset_id)
if create_asset:
session.commit()
assert deleted is expected_deleted
assert session.query(Asset).count() == expected_count
class TestBulkInsertReferencesIgnoreConflicts:
def test_inserts_multiple_references(self, session: Session):
asset = _make_asset(session, "hash1")
now = get_utc_now()
rows = [
{
"asset_id": asset.id,
"file_path": "/bulk1.bin",
"name": "bulk1",
"mtime_ns": 100,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"asset_id": asset.id,
"file_path": "/bulk2.bin",
"name": "bulk2",
"mtime_ns": 200,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetReference).count() == 2
def test_ignores_conflicts(self, session: Session):
asset = _make_asset(session, "hash1")
_make_reference(session, asset, "/existing.bin", mtime_ns=100)
session.commit()
now = get_utc_now()
rows = [
{
"asset_id": asset.id,
"file_path": "/existing.bin",
"name": "existing",
"mtime_ns": 999,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
{
"asset_id": asset.id,
"file_path": "/new.bin",
"name": "new",
"mtime_ns": 200,
"created_at": now,
"updated_at": now,
"last_access_time": now,
},
]
bulk_insert_references_ignore_conflicts(session, rows)
session.commit()
assert session.query(AssetReference).count() == 2
existing = session.query(AssetReference).filter_by(file_path="/existing.bin").one()
assert existing.mtime_ns == 100 # Original value preserved
def test_empty_list_is_noop(self, session: Session):
bulk_insert_references_ignore_conflicts(session, [])
assert session.query(AssetReference).count() == 0
class TestGetReferencesByPathsAndAssetIds:
def test_returns_matching_paths(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_reference(session, asset1, "/path1.bin")
_make_reference(session, asset2, "/path2.bin")
session.commit()
path_to_asset = {
"/path1.bin": asset1.id,
"/path2.bin": asset2.id,
}
winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
assert winners == {"/path1.bin", "/path2.bin"}
def test_excludes_non_matching_asset_ids(self, session: Session):
asset1 = _make_asset(session, "hash1")
asset2 = _make_asset(session, "hash2")
_make_reference(session, asset1, "/path1.bin")
session.commit()
# Path exists but with different asset_id
path_to_asset = {"/path1.bin": asset2.id}
winners = get_references_by_paths_and_asset_ids(session, path_to_asset)
assert winners == set()
def test_empty_dict_returns_empty(self, session: Session):
winners = get_references_by_paths_and_asset_ids(session, {})
assert winners == set()

Some files were not shown because too many files have changed in this diff Show More