Merge branch 'master' into mark-dtype-advanced
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
Christian Byrne 2026-03-17 07:02:45 -07:00 committed by GitHub
commit 782f09da4b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
56 changed files with 1957 additions and 444 deletions

103
.github/scripts/check-ai-co-authors.sh vendored Executable file
View File

@ -0,0 +1,103 @@
#!/usr/bin/env bash
# Checks pull request commits for AI agent Co-authored-by trailers.
# Exits non-zero when any are found and prints fix instructions.
set -euo pipefail
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
# Known AI coding-agent trailer patterns (case-insensitive).
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
AGENT_PATTERNS=(
# Anthropic — Claude Code / Amp
'noreply@anthropic\.com'
# Cursor
'cursoragent@cursor\.com'
# GitHub Copilot
'copilot-swe-agent\[bot\]'
'copilot@github\.com'
# OpenAI Codex
'noreply@openai\.com'
'codex@openai\.com'
# Aider
'aider@aider\.chat'
# Google — Gemini / Jules
'gemini@google\.com'
'jules@google\.com'
# Windsurf / Codeium
'@codeium\.com'
# Devin
'devin-ai-integration\[bot\]'
'devin@cognition\.ai'
'devin@cognition-labs\.com'
# Amazon Q Developer
'amazon-q-developer'
'@amazon\.com.*[Qq].[Dd]eveloper'
# Cline
'cline-bot'
'cline@cline\.ai'
# Continue
'continue-agent'
'continue@continue\.dev'
# Sourcegraph
'noreply@sourcegraph\.com'
# Generic catch-alls for common agent name patterns
'Co-authored-by:.*\b[Cc]laude\b'
'Co-authored-by:.*\b[Cc]opilot\b'
'Co-authored-by:.*\b[Cc]ursor\b'
'Co-authored-by:.*\b[Cc]odex\b'
'Co-authored-by:.*\b[Gg]emini\b'
'Co-authored-by:.*\b[Aa]ider\b'
'Co-authored-by:.*\b[Dd]evin\b'
'Co-authored-by:.*\b[Ww]indsurf\b'
'Co-authored-by:.*\b[Cc]line\b'
'Co-authored-by:.*\b[Aa]mazon Q\b'
'Co-authored-by:.*\b[Jj]ules\b'
'Co-authored-by:.*\bOpenCode\b'
)
# Build a single alternation regex from all patterns.
regex=""
for pattern in "${AGENT_PATTERNS[@]}"; do
if [[ -n "$regex" ]]; then
regex="${regex}|${pattern}"
else
regex="$pattern"
fi
done
# Collect Co-authored-by lines from every commit in the PR range.
violations=""
while IFS= read -r sha; do
message="$(git log -1 --format='%B' "$sha")"
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
if [[ -z "$matched_lines" ]]; then
continue
fi
while IFS= read -r line; do
if echo "$line" | grep -iqE "$regex"; then
short="$(git log -1 --format='%h' "$sha")"
violations="${violations} ${short}: ${line}"$'\n'
fi
done <<< "$matched_lines"
done < <(git rev-list "${base_sha}..${head_sha}")
if [[ -n "$violations" ]]; then
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
echo ""
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
echo ""
echo "$violations"
echo "These trailers should be removed before merging."
echo ""
echo "To fix, rewrite the commit messages with:"
echo " git rebase -i ${base_sha}"
echo ""
echo "and remove the Co-authored-by lines, then force-push your branch."
echo ""
echo "If you believe this is a false positive, please open an issue."
exit 1
fi
echo "No AI agent Co-authored-by trailers found."

View File

@ -0,0 +1,19 @@
name: Check AI Co-Authors
on:
pull_request:
branches: ['*']
jobs:
check-ai-co-authors:
name: Check for AI agent co-author trailers
runs-on: ubuntu-latest
steps:
- name: Checkout code
uses: actions/checkout@v4
with:
fetch-depth: 0
- name: Check commits for AI co-author trailers
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"

View File

@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started ## Get Started
### Local
#### [Desktop Application](https://www.comfy.org/download) #### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started. - The easiest way to get started.
- Available on Windows & macOS. - Available on Windows & macOS.
@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux) #### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend). Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/) ### Cloud
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version for those who can't afford local hardware.
## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.

View File

@ -8,7 +8,7 @@ from alembic import context
config = context.config config = context.config
from app.database.models import Base from app.database.models import Base, NAMING_CONVENTION
target_metadata = Base.metadata target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(
connection=connection, target_metadata=target_metadata connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
naming_convention=NAMING_CONVENTION,
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@ -0,0 +1,98 @@
"""
Add system_metadata and job_id columns to asset_references.
Change preview_id FK from assets.id to asset_references.id.
Revision ID: 0003_add_metadata_job_id
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
from app.database.models import NAMING_CONVENTION
revision = "0003_add_metadata_job_id"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
with op.batch_alter_table("asset_references") as batch_op:
batch_op.add_column(
sa.Column("system_metadata", sa.JSON(), nullable=True)
)
batch_op.add_column(
sa.Column("job_id", sa.String(length=36), nullable=True)
)
# Change preview_id FK from assets.id to asset_references.id (self-ref).
# Existing values are asset-content IDs that won't match reference IDs,
# so null them out first.
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_constraint(
"fk_asset_references_preview_id_assets", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_asset_references",
"asset_references",
["preview_id"],
["id"],
ondelete="SET NULL",
)
batch_op.create_index(
"ix_asset_references_preview_id", ["preview_id"]
)
# Purge any all-null meta rows before adding the constraint
op.execute(
"DELETE FROM asset_reference_meta"
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
)
with op.batch_alter_table("asset_reference_meta") as batch_op:
batch_op.create_check_constraint(
"ck_asset_reference_meta_has_value",
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
)
def downgrade() -> None:
# SQLite doesn't reflect CHECK constraints, so we must declare it
# explicitly via table_args for the batch recreate to find it.
# Use the fully-rendered constraint name to avoid the naming convention
# doubling the prefix.
with op.batch_alter_table(
"asset_reference_meta",
table_args=[
sa.CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="ck_asset_reference_meta_has_value",
),
],
) as batch_op:
batch_op.drop_constraint(
"ck_asset_reference_meta_has_value", type_="check"
)
with op.batch_alter_table(
"asset_references", naming_convention=NAMING_CONVENTION
) as batch_op:
batch_op.drop_index("ix_asset_references_preview_id")
batch_op.drop_constraint(
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
)
batch_op.create_foreign_key(
"fk_asset_references_preview_id_assets",
"assets",
["preview_id"],
["id"],
ondelete="SET NULL",
)
with op.batch_alter_table("asset_references") as batch_op:
batch_op.drop_column("job_id")
batch_op.drop_column("system_metadata")

View File

@ -13,6 +13,7 @@ from pydantic import ValidationError
import folder_paths import folder_paths
from app import user_manager from app import user_manager
from app.assets.api import schemas_in, schemas_out from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas
from app.assets.api.schemas_in import ( from app.assets.api.schemas_in import (
AssetValidationError, AssetValidationError,
UploadError, UploadError,
@ -38,6 +39,7 @@ from app.assets.services import (
update_asset_metadata, update_asset_metadata,
upload_from_temp_path, upload_from_temp_path,
) )
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None USER_MANAGER: user_manager.UserManager | None = None
@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at" return "created_at"
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
if not user_metadata:
return None
filename = user_metadata.get("filename")
if not filename:
return None
if "input" in tags:
view_type = "input"
elif "output" in tags:
view_type = "output"
else:
return None
subfolder = ""
if "/" in filename:
subfolder, filename = filename.rsplit("/", 1)
encoded_filename = urllib.parse.quote(filename, safe="")
url = f"/api/view?type={view_type}&filename={encoded_filename}"
if subfolder:
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
return url
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
"""Build an Asset response from a service result."""
if result.ref.preview_id:
preview_detail = get_asset_detail(result.ref.preview_id)
if preview_detail:
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
else:
preview_url = None
else:
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
return schemas_out.Asset(
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
preview_url=preview_url,
preview_id=result.ref.preview_id,
user_metadata=result.ref.user_metadata or {},
metadata=result.ref.system_metadata,
job_id=result.ref.job_id,
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
created_at=result.ref.created_at,
updated_at=result.ref.updated_at,
last_access_time=result.ref.last_access_time,
)
@ROUTES.head("/api/assets/hash/{hash}") @ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled @_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response: async def head_asset_by_hash(request: web.Request) -> web.Response:
@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order, order=order,
) )
summaries = [ summaries = [_build_asset_response(item) for item in result.items]
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList( payload = schemas_out.AssetsList(
assets=summaries, assets=summaries,
@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id}, {"id": reference_id},
) )
payload = schemas_out.AssetDetail( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e: except ValueError as e:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request), USER_MANAGER.get_request_user_id(request),
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON." 400, "INVALID_JSON", "Request body must be valid JSON."
) )
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash( result = create_from_hash(
hash_str=body.hash, hash_str=body.hash,
name=body.name, name=name,
tags=body.tags, tags=body.tags,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
) )
if result is None: if result is None:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
) )
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated( payload_out = schemas_out.AssetCreated(
id=result.ref.id, **asset.model_dump(),
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
return web.json_response(payload_out.model_dump(mode="json"), status=201) return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
@ROUTES.post("/api/assets") @ROUTES.post("/api/assets")
@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name, "name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw, "user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash, "hash": parsed.provided_hash,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
} }
) )
except ValidationError as ve: except ValidationError as ve:
@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags, tags=spec.tags,
user_metadata=spec.user_metadata or {}, user_metadata=spec.user_metadata or {},
owner_id=owner_id, owner_id=owner_id,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
) )
if result is None: if result is None:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name, client_filename=parsed.file_client_name,
owner_id=owner_id, owner_id=owner_id,
expected_hash=spec.hash, expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
) )
except AssetValidationError as e: except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id) logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated( asset = _build_asset_response(result)
id=result.ref.id, payload_out = schemas_out.AssetCreated(
name=result.ref.name, **asset.model_dump(),
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
status = 201 if result.created_new else 200 status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status) return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name, name=body.name,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
preview_id=body.preview_id,
) )
payload = schemas_out.AssetUpdated( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except PermissionError as pe: except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve: except ValueError as ve:
@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request), USER_MANAGER.get_request_user_id(request),
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsList( payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total tags=tags, total=total, has_more=(query.offset + len(tags)) < total
) )
return web.json_response(payload.model_dump(mode="json")) return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
"""GET request to get tag histogram for filtered assets."""
query_dict = get_query_dict(request)
try:
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
except ValidationError as ve:
return _build_validation_error_response("INVALID_QUERY", ve)
tag_counts = list_tag_histogram(
owner_id=USER_MANAGER.get_request_user_id(request),
include_tags=q.include_tags,
exclude_tags=q.exclude_tags,
name_contains=q.name_contains,
metadata_filter=q.metadata_filter,
limit=q.limit,
)
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.post("/api/assets/seed") @ROUTES.post("/api/assets/seed")

View File

@ -45,6 +45,8 @@ class ParsedUpload:
user_metadata_raw: str | None user_metadata_raw: str | None
provided_hash: str | None provided_hash: str | None
provided_hash_exists: bool | None provided_hash_exists: bool | None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel): class UpdateAssetBody(BaseModel):
name: str | None = None name: str | None = None
user_metadata: dict[str, Any] | None = None user_metadata: dict[str, Any] | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@model_validator(mode="after") @model_validator(mode="after")
def _validate_at_least_one_field(self): def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None: if all(
raise ValueError("Provide at least one of: name, user_metadata.") v is None
for v in (self.name, self.user_metadata, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, preview_id."
)
return self return self
@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str hash: str
name: str name: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@field_validator("hash") @field_validator("hash")
@classmethod @classmethod
@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel):
return [] return []
class TagsRefineQuery(BaseModel):
include_tags: list[str] = Field(default_factory=list)
exclude_tags: list[str] = Field(default_factory=list)
name_contains: str | None = None
metadata_filter: dict[str, Any] | None = None
limit: conint(ge=1, le=1000) = 100
@field_validator("include_tags", "exclude_tags", mode="before")
@classmethod
def _split_csv_tags(cls, v):
if v is None:
return []
if isinstance(v, str):
return [t.strip() for t in v.split(",") if t.strip()]
if isinstance(v, list):
out: list[str] = []
for item in v:
if isinstance(item, str):
out.extend([t.strip() for t in item.split(",") if t.strip()])
return out
return v
@field_validator("metadata_filter", mode="before")
@classmethod
def _parse_metadata_json(cls, v):
if v is None or isinstance(v, dict):
return v
if isinstance(v, str) and v.strip():
try:
parsed = json.loads(v)
except Exception as e:
raise ValueError(f"metadata_filter must be JSON: {e}") from e
if not isinstance(parsed, dict):
raise ValueError("metadata_filter must be a JSON object")
return parsed
return None
class TagsListQuery(BaseModel): class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@ -186,21 +234,25 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel): class UploadAssetSpec(BaseModel):
"""Upload Asset operation. """Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output'); - tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category if root == 'models', second must be a valid category
- name: display name - name: display name
- user_metadata: arbitrary JSON object (optional) - user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path - hash: optional canonical 'blake3:<hex>' for validation / fast-path
- mime_type: optional MIME type override
- preview_id: optional asset_reference ID for preview
Files are stored using the content hash as filename stem. Files are stored using the content hash as filename stem.
""" """
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1) tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name") name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None) hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None) # references an asset_reference id
@field_validator("hash", mode="before") @field_validator("hash", mode="before")
@classmethod @classmethod
@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_order(self): def _validate_order(self):
if not self.tags: if not self.tags:
raise ValueError("tags must be provided and non-empty") raise ValueError("at least one tag is required for uploads")
root = self.tags[0] root = self.tags[0]
if root not in {"models", "input", "output"}: if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output") raise ValueError("first tag must be one of: models, input, output")

View File

@ -4,7 +4,10 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel): class Asset(BaseModel):
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str id: str
name: str name: str
asset_hash: str | None = None asset_hash: str | None = None
@ -12,8 +15,14 @@ class AssetSummary(BaseModel):
mime_type: str | None = None mime_type: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None
created_at: datetime | None = None preview_id: str | None = None # references an asset_reference id, not an asset id
updated_at: datetime | None = None user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
job_id: str | None = None
prompt_id: str | None = None # deprecated: use job_id
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -23,50 +32,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None return v.isoformat() if v else None
class AssetCreated(Asset):
created_new: bool
class AssetsList(BaseModel): class AssetsList(BaseModel):
assets: list[AssetSummary] assets: list[Asset]
total: int total: int
has_more: bool has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel): class TagUsage(BaseModel):
name: str name: str
count: int count: int
@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list) removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
tag_counts: dict[str, int]

View File

@ -52,6 +52,8 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None user_metadata_raw: str | None = None
provided_hash: str | None = None provided_hash: str | None = None
provided_hash_exists: bool | None = None provided_hash_exists: bool | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0 file_written = 0
tmp_path: str | None = None tmp_path: str | None = None
@ -128,6 +130,16 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None provided_name = (await field.text()) or None
elif fname == "user_metadata": elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None user_metadata_raw = (await field.text()) or None
elif fname == "id":
raise UploadError(
400,
"UNSUPPORTED_FIELD",
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
)
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists): if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError( raise UploadError(
@ -152,6 +164,8 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw, user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash, provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists, provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
) )

View File

@ -45,13 +45,7 @@ class Asset(Base):
passive_deletes=True, passive_deletes=True,
) )
preview_of: Mapped[list[AssetReference]] = relationship( # preview_id on AssetReference is a self-referential FK to asset_references.id
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
__table_args__ = ( __table_args__ = (
Index("uq_assets_hash", "hash", unique=True), Index("uq_assets_hash", "hash", unique=True),
@ -91,11 +85,15 @@ class AssetReference(Base):
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False) name: Mapped[str] = mapped_column(String(512), nullable=False)
preview_id: Mapped[str | None] = mapped_column( preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL") String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
) )
user_metadata: Mapped[dict[str, Any] | None] = mapped_column( user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True) JSON(none_as_null=True)
) )
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now DateTime(timezone=False), nullable=False, default=get_utc_now
) )
@ -115,10 +113,10 @@ class AssetReference(Base):
foreign_keys=[asset_id], foreign_keys=[asset_id],
lazy="selectin", lazy="selectin",
) )
preview_asset: Mapped[Asset | None] = relationship( preview_ref: Mapped[AssetReference | None] = relationship(
"Asset", "AssetReference",
back_populates="preview_of",
foreign_keys=[preview_id], foreign_keys=[preview_id],
remote_side=lambda: [AssetReference.id],
) )
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
@ -152,6 +150,7 @@ class AssetReference(Base):
Index("ix_asset_references_created_at", "created_at"), Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"), Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"), Index("ix_asset_references_deleted_at", "deleted_at"),
Index("ix_asset_references_preview_id", "preview_id"),
Index("ix_asset_references_owner_name", "owner_id", "name"), Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint( CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
@ -192,6 +191,10 @@ class AssetReferenceMeta(Base):
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="has_value",
),
) )

View File

@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import (
get_unenriched_references, get_unenriched_references,
get_unreferenced_unhashed_asset_ids, get_unreferenced_unhashed_asset_ids,
insert_reference, insert_reference,
list_all_file_paths_by_asset_id,
list_references_by_asset_id, list_references_by_asset_id,
list_references_page, list_references_page,
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
rebuild_metadata_projection,
reference_exists,
reference_exists_for_asset_id, reference_exists_for_asset_id,
restore_references_by_paths, restore_references_by_paths,
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_system_metadata,
soft_delete_reference_by_id, soft_delete_reference_by_id,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_is_missing_by_asset_id,
update_reference_timestamps, update_reference_timestamps,
update_reference_updated_at, update_reference_updated_at,
upsert_reference, upsert_reference,
@ -54,6 +59,7 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta, bulk_insert_tags_and_meta,
ensure_tags_exist, ensure_tags_exist,
get_reference_tags, get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage, list_tags_with_usage,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
remove_tags_from_reference, remove_tags_from_reference,
@ -97,20 +103,26 @@ __all__ = [
"get_unenriched_references", "get_unenriched_references",
"get_unreferenced_unhashed_asset_ids", "get_unreferenced_unhashed_asset_ids",
"insert_reference", "insert_reference",
"list_all_file_paths_by_asset_id",
"list_references_by_asset_id", "list_references_by_asset_id",
"list_references_page", "list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage", "list_tags_with_usage",
"mark_references_missing_outside_prefixes", "mark_references_missing_outside_prefixes",
"reassign_asset_references", "reassign_asset_references",
"rebuild_metadata_projection",
"reference_exists",
"reference_exists_for_asset_id", "reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id", "remove_missing_tag_for_asset_id",
"remove_tags_from_reference", "remove_tags_from_reference",
"restore_references_by_paths", "restore_references_by_paths",
"set_reference_metadata", "set_reference_metadata",
"set_reference_preview", "set_reference_preview",
"set_reference_system_metadata",
"soft_delete_reference_by_id", "soft_delete_reference_by_id",
"set_reference_tags", "set_reference_tags",
"update_asset_hash_and_mime", "update_asset_hash_and_mime",
"update_is_missing_by_asset_id",
"update_reference_access_time", "update_reference_access_time",
"update_reference_name", "update_reference_name",
"update_reference_timestamps", "update_reference_timestamps",

View File

@ -69,7 +69,7 @@ def upsert_asset(
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes) asset.size_bytes = int(size_bytes)
changed = True changed = True
if mime_type and asset.mime_type != mime_type: if mime_type and not asset.mime_type:
asset.mime_type = mime_type asset.mime_type = mime_type
changed = True changed = True
if changed: if changed:
@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
return False return False
if asset_hash is not None: if asset_hash is not None:
asset.hash = asset_hash asset.hash = asset_hash
if mime_type is not None: if mime_type is not None and not asset.mime_type:
asset.mime_type = mime_type asset.mime_type = mime_type
return True return True

View File

@ -10,7 +10,7 @@ from decimal import Decimal
from typing import NamedTuple, Sequence from typing import NamedTuple, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import delete, exists, select from sqlalchemy import delete, select
from sqlalchemy.dialects import sqlite from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, noload from sqlalchemy.orm import Session, noload
@ -24,12 +24,14 @@ from app.assets.database.models import (
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
MAX_BIND_PARAMS, MAX_BIND_PARAMS,
apply_metadata_filter,
apply_tag_filters,
build_prefix_like_conditions, build_prefix_like_conditions,
build_visible_owner_clause, build_visible_owner_clause,
calculate_rows_per_statement, calculate_rows_per_statement,
iter_chunks, iter_chunks,
) )
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags from app.assets.helpers import escape_sql_like_string, get_utc_now
def _check_is_scalar(v): def _check_is_scalar(v):
@ -44,15 +46,6 @@ def _check_is_scalar(v):
def _scalar_to_row(key: str, ordinal: int, value) -> dict: def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row.""" """Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool): if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)): if isinstance(value, (int, float, Decimal)):
@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
def convert_metadata_to_rows(key: str, value) -> list[dict]: def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""Turn a metadata key/value into typed projection rows.""" """Turn a metadata key/value into typed projection rows."""
if value is None: if value is None:
return [_scalar_to_row(key, 0, None)] return []
if _check_is_scalar(value): if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)] return [_scalar_to_row(key, 0, value)]
if isinstance(value, list): if isinstance(value, list):
if all(_check_is_scalar(x) for x in value): if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
return [{"key": key, "ordinal": 0, "val_json": value}] return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetReferenceMeta.val_json.is_(None),
AssetReferenceMeta.val_str.is_(None),
AssetReferenceMeta.val_num.is_(None),
AssetReferenceMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def get_reference_by_id( def get_reference_by_id(
@ -212,6 +128,21 @@ def reference_exists_for_asset_id(
return session.execute(q).first() is not None return session.execute(q).first() is not None
def reference_exists(
session: Session,
reference_id: str,
) -> bool:
"""Return True if a reference with the given ID exists (not soft-deleted)."""
q = (
select(sa.literal(True))
.select_from(AssetReference)
.where(AssetReference.id == reference_id)
.where(AssetReference.deleted_at.is_(None))
.limit(1)
)
return session.execute(q).first() is not None
def insert_reference( def insert_reference(
session: Session, session: Session,
asset_id: str, asset_id: str,
@ -336,8 +267,8 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains) escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower() sort = (sort or "created_at").lower()
order = (order or "desc").lower() order = (order or "desc").lower()
@ -366,8 +297,8 @@ def list_references_page(
count_stmt = count_stmt.where( count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc) AssetReference.name.ilike(f"%{escaped}%", escape=esc)
) )
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0) total = int(session.execute(count_stmt).scalar_one() or 0)
refs = session.execute(base).unique().scalars().all() refs = session.execute(base).unique().scalars().all()
@ -379,7 +310,7 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name) select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name) .join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list)) .where(AssetReferenceTag.asset_reference_id.in_(id_list))
.order_by(AssetReferenceTag.added_at) .order_by(AssetReferenceTag.tag_name.asc())
) )
for ref_id, tag_name in rows.all(): for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name) tag_map[ref_id].append(tag_name)
@ -492,6 +423,42 @@ def update_reference_updated_at(
) )
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
override system keys of the same name.
"""
session.execute(
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == ref.id
)
)
session.flush()
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
if not merged:
return
rows: list[AssetReferenceMeta] = []
for k, v in merged.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=ref.id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def set_reference_metadata( def set_reference_metadata(
session: Session, session: Session,
reference_id: str, reference_id: str,
@ -505,33 +472,24 @@ def set_reference_metadata(
ref.updated_at = get_utc_now() ref.updated_at = get_utc_now()
session.flush() session.flush()
session.execute( rebuild_metadata_projection(session, ref)
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == reference_id
) def set_reference_system_metadata(
) session: Session,
reference_id: str,
system_metadata: dict | None = None,
) -> None:
"""Set system_metadata on a reference and rebuild the merged projection."""
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
ref.system_metadata = system_metadata or {}
ref.updated_at = get_utc_now()
session.flush() session.flush()
if not user_metadata: rebuild_metadata_projection(session, ref)
return
rows: list[AssetReferenceMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=reference_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_reference_by_id( def delete_reference_by_id(
@ -571,19 +529,19 @@ def soft_delete_reference_by_id(
def set_reference_preview( def set_reference_preview(
session: Session, session: Session,
reference_id: str, reference_id: str,
preview_asset_id: str | None = None, preview_reference_id: str | None = None,
) -> None: ) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
ref = session.get(AssetReference, reference_id) ref = session.get(AssetReference, reference_id)
if not ref: if not ref:
raise ValueError(f"AssetReference {reference_id} not found") raise ValueError(f"AssetReference {reference_id} not found")
if preview_asset_id is None: if preview_reference_id is None:
ref.preview_id = None ref.preview_id = None
else: else:
if not session.get(Asset, preview_asset_id): if not session.get(AssetReference, preview_reference_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found") raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
ref.preview_id = preview_asset_id ref.preview_id = preview_reference_id
ref.updated_at = get_utc_now() ref.updated_at = get_utc_now()
session.flush() session.flush()
@ -609,6 +567,8 @@ def list_references_by_asset_id(
session.execute( session.execute(
select(AssetReference) select(AssetReference)
.where(AssetReference.asset_id == asset_id) .where(AssetReference.asset_id == asset_id)
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
.order_by(AssetReference.id.asc()) .order_by(AssetReference.id.asc())
) )
.scalars() .scalars()
@ -616,6 +576,25 @@ def list_references_by_asset_id(
) )
def list_all_file_paths_by_asset_id(
session: Session,
asset_id: str,
) -> list[str]:
"""Return every file_path for an asset, including soft-deleted/missing refs.
Used for orphan cleanup where all on-disk files must be removed.
"""
return list(
session.execute(
select(AssetReference.file_path)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.file_path.isnot(None))
)
.scalars()
.all()
)
def upsert_reference( def upsert_reference(
session: Session, session: Session,
asset_id: str, asset_id: str,
@ -855,6 +834,22 @@ def bulk_update_is_missing(
return total return total
def update_is_missing_by_asset_id(
session: Session, asset_id: str, value: bool
) -> int:
"""Set is_missing flag for ALL references belonging to an asset.
Returns: Number of rows updated
"""
result = session.execute(
sa.update(AssetReference)
.where(AssetReference.asset_id == asset_id)
.where(AssetReference.deleted_at.is_(None))
.values(is_missing=value)
)
return result.rowcount
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
"""Delete references by their IDs. """Delete references by their IDs.

View File

@ -1,12 +1,14 @@
"""Shared utilities for database query modules.""" """Shared utilities for database query modules."""
import os import os
from typing import Iterable from decimal import Decimal
from typing import Iterable, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800 MAX_BIND_PARAMS = 800
@ -52,3 +54,74 @@ def build_prefix_like_conditions(
escaped, esc = escape_sql_like_string(base) escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds return conds
def apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
return sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt

View File

@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.database.models import ( from app.assets.database.models import (
Asset,
AssetReference, AssetReference,
AssetReferenceMeta, AssetReferenceMeta,
AssetReferenceTag, AssetReferenceTag,
Tag, Tag,
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause, build_visible_owner_clause,
iter_row_chunks, iter_row_chunks,
) )
@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
tag_name tag_name
for (tag_name,) in ( for (tag_name,) in (
session.execute( session.execute(
select(AssetReferenceTag.tag_name).where( select(AssetReferenceTag.tag_name)
AssetReferenceTag.asset_reference_id == reference_id .where(AssetReferenceTag.asset_reference_id == reference_id)
) .order_by(AssetReferenceTag.tag_name.asc())
) )
).all() ).all()
] ]
@ -117,7 +120,7 @@ def set_reference_tags(
) )
session.flush() session.flush()
return SetTagsResult(added=to_add, removed=to_remove, total=desired) return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
def add_tags_to_reference( def add_tags_to_reference(
@ -272,6 +275,12 @@ def list_tags_with_usage(
.select_from(AssetReferenceTag) .select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id)) .where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None)) .where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name) .group_by(AssetReferenceTag.tag_name)
.subquery() .subquery()
@ -308,6 +317,12 @@ def list_tags_with_usage(
select(AssetReferenceTag.tag_name) select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id)) .where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None)) .where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name) .group_by(AssetReferenceTag.tag_name)
) )
@ -320,6 +335,53 @@ def list_tags_with_usage(
return rows_norm, int(total or 0) return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
session: Session,
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
"""Return tag counts for assets matching the given filters.
Uses the same filtering logic as list_references_page but returns
{tag_name: count} instead of paginated references.
"""
# Build a subquery of matching reference IDs
ref_sq = (
select(AssetReference.id)
.join(Asset, Asset.id == AssetReference.asset_id)
.where(build_visible_owner_clause(owner_id))
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
)
if name_contains:
escaped, esc = escape_sql_like_string(name_contains)
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
ref_sq = ref_sq.subquery()
# Count tags across those references
q = (
select(
AssetReferenceTag.tag_name,
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
)
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
.group_by(AssetReferenceTag.tag_name)
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
.limit(limit)
)
rows = session.execute(q).all()
return {tag_name: int(cnt) for tag_name, cnt in rows}
def bulk_insert_tags_and_meta( def bulk_insert_tags_and_meta(
session: Session, session: Session,
tag_rows: list[dict], tag_rows: list[dict],

View File

@ -18,7 +18,7 @@ from app.assets.database.queries import (
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
reassign_asset_references, reassign_asset_references,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_system_metadata,
update_asset_hash_and_mime, update_asset_hash_and_mime,
) )
from app.assets.services.bulk_ingest import ( from app.assets.services.bulk_ingest import (
@ -490,8 +490,8 @@ def enrich_asset(
logging.warning("Failed to hash %s: %s", file_path, e) logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata: if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata() system_metadata = metadata.to_user_metadata()
set_reference_metadata(session, reference_id, user_metadata) set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash: if full_hash:
existing = get_asset_by_hash(session, full_hash) existing = get_asset_by_hash(session, full_hash)

View File

@ -16,10 +16,12 @@ from app.assets.database.queries import (
get_reference_by_id, get_reference_by_id,
get_reference_with_owner_check, get_reference_with_owner_check,
list_references_page, list_references_page,
list_all_file_paths_by_asset_id,
list_references_by_asset_id, list_references_by_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_reference_updated_at, update_reference_updated_at,
@ -67,6 +69,8 @@ def update_asset_metadata(
user_metadata: UserMetadata = None, user_metadata: UserMetadata = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id) ref = get_reference_with_owner_check(session, reference_id, owner_id)
@ -103,6 +107,21 @@ def update_asset_metadata(
) )
touched = True touched = True
if mime_type is not None:
updated = update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
if updated:
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_id,
)
touched = True
if touched and user_metadata is None: if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id) update_reference_updated_at(session, reference_id=reference_id)
@ -159,11 +178,9 @@ def delete_asset_reference(
session.commit() session.commit()
return True return True
# Orphaned asset - delete it and its files # Orphaned asset - gather ALL file paths (including
refs = list_references_by_asset_id(session, asset_id=asset_id) # soft-deleted / missing refs) so their on-disk files get cleaned up.
file_paths = [ file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
]
# Also include the just-deleted file path # Also include the just-deleted file path
if file_path: if file_path:
file_paths.append(file_path) file_paths.append(file_path)
@ -185,7 +202,7 @@ def delete_asset_reference(
def set_asset_preview( def set_asset_preview(
reference_id: str, reference_id: str,
preview_asset_id: str | None = None, preview_reference_id: str | None = None,
owner_id: str = "", owner_id: str = "",
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
@ -194,7 +211,7 @@ def set_asset_preview(
set_reference_preview( set_reference_preview(
session, session,
reference_id=reference_id, reference_id=reference_id,
preview_asset_id=preview_asset_id, preview_reference_id=preview_reference_id,
) )
result = fetch_reference_asset_and_tags( result = fetch_reference_asset_and_tags(
@ -263,6 +280,47 @@ def list_assets_page(
return ListAssetsResult(items=items, total=total) return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(
asset_hash: str,
owner_id: str = "",
) -> DownloadResolutionResult | None:
"""Resolve a blake3 hash to an on-disk file path.
Only references visible to *owner_id* are considered (owner-less
references are always visible).
Returns a DownloadResolutionResult with abs_path, content_type, and
download_name, or None if no asset or live path is found.
"""
with create_session() as session:
asset = queries_get_asset_by_hash(session, asset_hash)
if not asset:
return None
refs = list_references_by_asset_id(session, asset_id=asset.id)
visible = [
r for r in refs
if r.owner_id == "" or r.owner_id == owner_id
]
abs_path = select_best_live_path(visible)
if not abs_path:
return None
display_name = os.path.basename(abs_path)
for ref in visible:
if ref.file_path == abs_path and ref.name:
display_name = ref.name
break
ctype = (
asset.mime_type
or mimetypes.guess_type(display_name)[0]
or "application/octet-stream"
)
return DownloadResolutionResult(
abs_path=abs_path,
content_type=ctype,
download_name=display_name,
)
def resolve_asset_for_download( def resolve_asset_for_download(
reference_id: str, reference_id: str,
owner_id: str = "", owner_id: str = "",

View File

@ -11,13 +11,14 @@ from app.assets.database.queries import (
add_tags_to_reference, add_tags_to_reference,
fetch_reference_and_asset, fetch_reference_and_asset,
get_asset_by_hash, get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path, get_reference_by_file_path,
get_reference_tags, get_reference_tags,
get_or_create_reference, get_or_create_reference,
reference_exists,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
upsert_asset, upsert_asset,
upsert_reference, upsert_reference,
validate_tags_exist, validate_tags_exist,
@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_relative_filename, compute_relative_filename,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags, resolve_destination_from_tags,
validate_path_within_base, validate_path_within_base,
) )
@ -65,7 +67,7 @@ def _ingest_file_from_path(
with create_session() as session: with create_session() as session:
if preview_id: if preview_id:
if preview_id not in get_existing_asset_ids(session, [preview_id]): if not reference_exists(session, preview_id):
preview_id = None preview_id = None
asset, asset_created, asset_updated = upsert_asset( asset, asset_created, asset_updated = upsert_asset(
@ -135,6 +137,8 @@ def _register_existing_asset(
tags: list[str] | None = None, tags: list[str] | None = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> RegisterAssetResult: ) -> RegisterAssetResult:
user_metadata = user_metadata or {} user_metadata = user_metadata or {}
@ -143,14 +147,25 @@ def _register_existing_asset(
if not asset: if not asset:
raise ValueError(f"No asset with hash {asset_hash}") raise ValueError(f"No asset with hash {asset_hash}")
if mime_type and not asset.mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
ref, ref_created = get_or_create_reference( ref, ref_created = get_or_create_reference(
session, session,
asset_id=asset.id, asset_id=asset.id,
owner_id=owner_id, owner_id=owner_id,
name=name, name=name,
preview_id=preview_id,
) )
if not ref_created: if not ref_created:
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
tag_names = get_reference_tags(session, reference_id=ref.id) tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult( result = RegisterAssetResult(
ref=extract_reference_data(ref), ref=extract_reference_data(ref),
@ -242,6 +257,8 @@ def upload_from_temp_path(
client_filename: str | None = None, client_filename: str | None = None,
owner_id: str = "", owner_id: str = "",
expected_hash: str | None = None, expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult: ) -> UploadResult:
try: try:
digest, _ = hashing.compute_blake3_hash(temp_path) digest, _ = hashing.compute_blake3_hash(temp_path)
@ -270,6 +287,8 @@ def upload_from_temp_path(
tags=tags or [], tags=tags or [],
tag_origin="manual", tag_origin="manual",
owner_id=owner_id, owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
) )
return UploadResult( return UploadResult(
ref=result.ref, ref=result.ref,
@ -291,7 +310,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir) validate_path_within_base(dest_abs, base_dir)
content_type = ( content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0] or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream" or "application/octet-stream"
@ -315,7 +334,7 @@ def upload_from_temp_path(
mime_type=content_type, mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest), info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id, owner_id=owner_id,
preview_id=None, preview_id=preview_id,
user_metadata=user_metadata or {}, user_metadata=user_metadata or {},
tags=tags, tags=tags,
tag_origin="manual", tag_origin="manual",
@ -342,30 +361,99 @@ def upload_from_temp_path(
) )
def register_file_in_place(
abs_path: str,
name: str,
tags: list[str],
owner_id: str = "",
mime_type: str | None = None,
) -> UploadResult:
"""Register an already-saved file in the asset database without moving it.
Tags are derived from the filesystem path (root category + subfolder names),
merged with any caller-provided tags, matching the behavior of the scanner.
If the path is not under a known root, only the caller-provided tags are used.
"""
try:
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
except ValueError:
path_tags = []
merged_tags = normalize_tags([*path_tags, *tags])
try:
digest, _ = hashing.compute_blake3_hash(abs_path)
except ImportError as e:
raise DependencyMissingError(str(e))
except Exception as e:
raise RuntimeError(f"failed to hash file: {e}")
asset_hash = "blake3:" + digest
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
content_type = mime_type or (
mimetypes.guess_type(abs_path, strict=False)[0]
or "application/octet-stream"
)
ingest_result = _ingest_file_from_path(
abs_path=abs_path,
asset_hash=asset_hash,
size_bytes=size_bytes,
mtime_ns=mtime_ns,
mime_type=content_type,
info_name=_sanitize_filename(name, fallback=digest),
owner_id=owner_id,
tags=merged_tags,
tag_origin="upload",
require_existing_tags=False,
)
reference_id = ingest_result.reference_id
if not reference_id:
raise RuntimeError("failed to create asset reference")
with create_session() as session:
pair = fetch_reference_and_asset(
session, reference_id=reference_id, owner_id=owner_id
)
if not pair:
raise RuntimeError("inconsistent DB state after ingest")
ref, asset = pair
tag_names = get_reference_tags(session, reference_id=ref.id)
return UploadResult(
ref=extract_reference_data(ref),
asset=extract_asset_data(asset),
tags=tag_names,
created_new=ingest_result.asset_created,
)
def create_from_hash( def create_from_hash(
hash_str: str, hash_str: str,
name: str, name: str,
tags: list[str] | None = None, tags: list[str] | None = None,
user_metadata: dict | None = None, user_metadata: dict | None = None,
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult | None: ) -> UploadResult | None:
canonical = hash_str.strip().lower() canonical = hash_str.strip().lower()
with create_session() as session: try:
asset = get_asset_by_hash(session, asset_hash=canonical) result = _register_existing_asset(
if not asset: asset_hash=canonical,
return None name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
result = _register_existing_asset( ),
asset_hash=canonical, user_metadata=user_metadata or {},
name=_sanitize_filename( tags=tags or [],
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical tag_origin="manual",
), owner_id=owner_id,
user_metadata=user_metadata or {}, mime_type=mime_type,
tags=tags or [], preview_id=preview_id,
tag_origin="manual", )
owner_id=owner_id, except ValueError:
) logging.warning("create_from_hash: no asset found for hash %s", canonical)
return None
return UploadResult( return UploadResult(
ref=result.ref, ref=result.ref,

View File

@ -25,7 +25,9 @@ class ReferenceData:
preview_id: str | None preview_id: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
last_access_time: datetime | None system_metadata: dict[str, Any] | None = None
job_id: str | None = None
last_access_time: datetime | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path, file_path=ref.file_path,
user_metadata=ref.user_metadata, user_metadata=ref.user_metadata,
preview_id=ref.preview_id, preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
job_id=ref.job_id,
created_at=ref.created_at, created_at=ref.created_at,
updated_at=ref.updated_at, updated_at=ref.updated_at,
last_access_time=ref.last_access_time, last_access_time=ref.last_access_time,

View File

@ -1,3 +1,5 @@
from typing import Sequence
from app.assets.database.queries import ( from app.assets.database.queries import (
AddTagsResult, AddTagsResult,
RemoveTagsResult, RemoveTagsResult,
@ -6,6 +8,7 @@ from app.assets.database.queries import (
list_tags_with_usage, list_tags_with_usage,
remove_tags_from_reference, remove_tags_from_reference,
) )
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage from app.assets.services.schemas import TagUsage
from app.database.db import create_session from app.database.db import create_session
@ -73,3 +76,23 @@ def list_tags(
) )
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
owner_id: str = "",
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
name_contains: str | None = None,
metadata_filter: dict | None = None,
limit: int = 100,
) -> dict[str, int]:
with create_session() as session:
return list_tag_counts_for_filtered_assets(
session,
owner_id=owner_id,
include_tags=include_tags,
exclude_tags=exclude_tags,
name_contains=name_contains,
metadata_filter=metadata_filter,
limit=limit,
)

View File

@ -1,9 +1,18 @@
from typing import Any from typing import Any
from datetime import datetime from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass metadata = MetaData(naming_convention=NAMING_CONVENTION)
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys() fields = obj.__table__.columns.keys()

View File

@ -6,6 +6,7 @@ import uuid
import glob import glob
import shutil import shutil
import logging import logging
import tempfile
from aiohttp import web from aiohttp import web
from urllib import parse from urllib import parse
from comfy.cli_args import args from comfy.cli_args import args
@ -377,8 +378,15 @@ class UserManager():
try: try:
body = await request.read() body = await request.read()
with open(path, "wb") as f: dir_name = os.path.dirname(path)
f.write(body) fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
except OSError as e: except OSError as e:
logging.warning(f"Error saving file '{path}': {e}") logging.warning(f"Error saving file '{path}': {e}")
return web.Response( return web.Response(

View File

@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
@ -147,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@ -260,4 +263,6 @@ else:
args.fast = set(args.fast) args.fast = set(args.fast)
def enables_dynamic_vram(): def enables_dynamic_vram():
if args.enable_dynamic_vram:
return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block) output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False) return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
def roundup(x_val, multiple):
return ((x_val + multiple - 1) // multiple) * multiple
if pad_32x:
rows, cols = x.shape
padded_rows = roundup(rows, 32)
padded_cols = roundup(cols, 32)
if padded_rows != rows or padded_cols != cols:
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
F8_E4M3_MAX = 448.0
E8M0_BIAS = 127
BLOCK_SIZE = 32
rows, cols = x.shape
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
# E8M0 block scales (power-of-2 exponents)
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
block_scales_e8m0 = exp_biased.to(torch.uint8)
zero_mask = (max_abs == 0)
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
# Scale per-block then stochastic round
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)

View File

@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim), k.reshape(b, s2, self.num_heads * self.head_dim),
v, v,
heads=self.num_heads, heads=self.num_heads,
low_precision_attention=False,
) )
out = self.out_proj(x) out = self.out_proj(x)
@ -412,6 +413,7 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim), key.reshape(B, N, self.num_heads * self.head_dim),
value, value,
heads=self.num_heads, heads=self.num_heads,
low_precision_attention=False,
) )
x = self.out_proj(x) x = self.out_proj(x)

View File

@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -536,7 +537,7 @@ class Decoder(nn.Module):
mark_conv3d_ended(self.conv_out) mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal) sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0: if sample is not None and sample.shape[2] > 0:
output.append(sample) output.append(sample.to(comfy.model_management.intermediate_device()))
return return
up_block = self.up_blocks[idx] up_block = self.up_blocks[idx]

View File

@ -1,9 +1,68 @@
import math import math
import ctypes
import threading
import dataclasses
import torch import torch
from typing import NamedTuple from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
file_ref: object
thread_id: int
offset: int
size: int
def read_tensor_file_slice_into(tensor, destination):
if isinstance(tensor, QuantizedTensor):
if not isinstance(destination, QuantizedTensor):
return False
if tensor._layout_cls != destination._layout_cls:
return False
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
return False
dst_orig_dtype = destination._params.orig_dtype
destination._params.copy_from(tensor._params, non_blocking=False)
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
return True
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
if info is None:
return False
file_obj = info.file_ref
if (destination.device.type != "cpu"
or file_obj is None
or threading.get_ident() != info.thread_id
or destination.numel() * destination.element_size() < info.size):
return False
if info.size == 0:
return True
buf_type = ctypes.c_ubyte * info.size
view = memoryview(buf_type.from_address(destination.data_ptr()))
try:
file_obj.seek(info.offset)
done = 0
while done < info.size:
try:
n = file_obj.readinto(view[done:])
except OSError:
return False
if n <= 0:
return False
done += n
return True
finally:
view.release()
class TensorGeometry(NamedTuple): class TensorGeometry(NamedTuple):
shape: any shape: any
dtype: torch.dtype dtype: torch.dtype

View File

@ -400,7 +400,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0): if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]): if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@ -505,6 +505,28 @@ def module_size(module):
module_mem += t.nbytes module_mem += t.nbytes
return module_mem return module_mem
def module_mmap_residency(module, free=False):
mmap_touched_mem = 0
module_mem = 0
bounced_mmaps = set()
sd = module.state_dict()
for k in sd:
t = sd[k]
module_mem += t.nbytes
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
continue
mmap_touched_mem += t.nbytes
if not free:
continue
storage._comfy_tensor_mmap_touched = False
mmap_obj = storage._comfy_tensor_mmap_refs[0]
if mmap_obj in bounced_mmaps:
continue
mmap_obj.bounce()
bounced_mmaps.add(mmap_obj)
return mmap_touched_mem, module_mem
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self._set_model(model) self._set_model(model)
@ -519,6 +541,7 @@ class LoadedModel:
if model.parent is not None: if model.parent is not None:
self._parent_model = weakref.ref(model.parent) self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent) self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
self._patcher_finalizer.atexit = False
def _switch_parent(self): def _switch_parent(self):
model = self._parent_model() model = self._parent_model()
@ -532,6 +555,9 @@ class LoadedModel:
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
def model_mmap_residency(self, free=False):
return self.model.model_mmap_residency(free=free)
def model_loaded_memory(self): def model_loaded_memory(self):
return self.model.loaded_size() return self.model.loaded_size()
@ -562,6 +588,7 @@ class LoadedModel:
self.real_model = weakref.ref(real_model) self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models) self.model_finalizer = weakref.finalize(real_model, cleanup_models)
self.model_finalizer.atexit = False
return real_model return real_model
def should_reload_model(self, force_patch_weights=False): def should_reload_model(self, force_patch_weights=False):
@ -633,7 +660,7 @@ def extra_reserved_memory():
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
@ -646,13 +673,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
for x in sorted(can_unload): can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
ram_to_free = 1e32 pins_to_free = 1e32
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device) memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram() pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic: if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand. #as that works on-demand.
@ -661,9 +689,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i) unloaded_model.append(i)
if ram_to_free > 0: if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
@ -729,17 +766,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {} total_memory_required = {}
total_pins_required = {}
total_ram_required = {} total_ram_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) device = loaded_model.device
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
#want to do. resident_memory, model_memory = loaded_model.model_mmap_residency()
#FIXME: This should subtract off the to_load current pin consumption. pinned_memory = loaded_model.model.pinned_memory_size()
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
@ -1005,6 +1052,12 @@ def intermediate_device():
else: else:
return torch.device("cpu") return torch.device("cpu")
def intermediate_dtype():
if args.fp16_intermediates:
return torch.float16
else:
return torch.float32
def vae_device(): def vae_device():
if args.cpu_vae: if args.cpu_vae:
return torch.device("cpu") return torch.device("cpu")
@ -1225,6 +1278,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0) dest_view = dest_views.pop(0)
if tensor is None: if tensor is None:
continue continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking) dest_view.copy_(tensor, non_blocking=non_blocking)
@ -1662,6 +1720,19 @@ def supports_nvfp4_compute(device=None):
return True return True
def supports_mxfp8_compute(device=None):
if not is_nvidia():
return False
if torch_version_numeric < (2, 10):
return False
props = torch.cuda.get_device_properties(device)
if props.major < 10:
return False
return True
def extended_fp16_support(): def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older # TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7): if torch_version_numeric < (2, 7):

View File

@ -297,6 +297,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def model_mmap_residency(self, free=False):
return comfy.model_management.module_mmap_residency(self.model, free=free)
def get_ram_usage(self): def get_ram_usage(self):
return self.model_size() return self.model_size()
@ -1063,6 +1066,10 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used return self.model.model_loaded_weight_memory - current_used
def pinned_memory_size(self):
# Pinned memory pressure tracking is only implemented for DynamicVram loading
return 0
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
pass pass
@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed return freed
def pinned_memory_size(self):
total = 0
loading = self._load_list(for_dynamic=True)
for x in loading:
_, _, _, _, m, _ = x
pin = comfy.pinned_memory.get_pin(m)
if pin is not None:
total += pin.numel() * pin.element_size()
return total
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device) loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading: for x in loading:

View File

@ -306,10 +306,40 @@ class CastWeightBiasOp:
bias_function = [] bias_function = []
class disable_weight_init: class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp): class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: # don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
return return
@ -330,32 +360,21 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata, def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs): strict, missing_keys, unexpected_keys, error_msgs):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict, return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs) missing_keys, unexpected_keys, error_msgs)
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False) disable_weight_init._lazy_load_from_state_dict(
prefix_len = len(prefix) self,
for k,v in state_dict.items(): state_dict,
if k[prefix_len:] == "weight": prefix,
if not assign_to_params_buffers: local_metadata,
v = v.clone() missing_keys,
self.weight = torch.nn.Parameter(v, requires_grad=False) unexpected_keys,
elif k[prefix_len:] == "bias" and v is not None: weight_shape=(self.in_features, self.out_features),
if not assign_to_params_buffers: bias_shape=(self.out_features,),
v = v.clone() )
self.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
#Reconcile default construction of the weight if its missing.
if self.weight is None:
v = torch.zeros(self.in_features, self.out_features)
self.weight = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"weight")
if self.bias is None and self.comfy_need_lazy_init_bias:
v = torch.zeros(self.out_features,)
self.bias = torch.nn.Parameter(v, requires_grad=False)
missing_keys.append(prefix+"bias")
def reset_parameters(self): def reset_parameters(self):
@ -547,6 +566,53 @@ class disable_weight_init:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp): class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
_freeze=False, device=None, dtype=None):
# don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
norm_type, scale_grad_by_freq, sparse, _weight,
_freeze, device, dtype)
return
torch.nn.Module.__init__(self)
self.num_embeddings = num_embeddings
self.embedding_dim = embedding_dim
self.padding_idx = padding_idx
self.max_norm = max_norm
self.norm_type = norm_type
self.scale_grad_by_freq = scale_grad_by_freq
self.sparse = sparse
# Keep shape/dtype visible for module introspection without reserving storage.
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
self.weight = torch.nn.Parameter(
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
requires_grad=False,
)
self.bias = None
self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
strict, missing_keys, unexpected_keys, error_msgs):
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
missing_keys, unexpected_keys, error_msgs)
disable_weight_init._lazy_load_from_state_dict(
self,
state_dict,
prefix,
local_metadata,
missing_keys,
unexpected_keys,
weight_shape=(self.num_embeddings, self.embedding_dim),
)
def reset_parameters(self): def reset_parameters(self):
self.bias = None self.bias = None
return None return None
@ -710,6 +776,71 @@ from .quant_ops import (
) )
class QuantLinearFunc(torch.autograd.Function):
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
Handles any input rank by flattening to 2D for matmul and restoring shape after.
"""
@staticmethod
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
input_shape = input_float.shape
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
# Quantize input (same as inference path)
if layout_type is not None:
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
else:
q_input = inp
w = weight.detach() if weight.requires_grad else weight
b = bias.detach() if bias is not None and bias.requires_grad else bias
output = torch.nn.functional.linear(q_input, w, b)
# Restore original input shape
if len(input_shape) > 2:
output = output.unflatten(0, input_shape[:-1])
ctx.save_for_backward(input_float, weight)
ctx.input_shape = input_shape
ctx.has_bias = bias is not None
ctx.compute_dtype = compute_dtype
ctx.weight_requires_grad = weight.requires_grad
return output
@staticmethod
@torch.autograd.function.once_differentiable
def backward(ctx, grad_output):
input_float, weight = ctx.saved_tensors
compute_dtype = ctx.compute_dtype
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
# Dequantize weight to compute dtype for backward matmul
if isinstance(weight, QuantizedTensor):
weight_f = weight.dequantize().to(compute_dtype)
else:
weight_f = weight.to(compute_dtype)
# grad_input = grad_output @ weight
grad_input = torch.mm(grad_2d, weight_f)
if len(ctx.input_shape) > 2:
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
# grad_weight (only if weight requires grad, typically frozen for quantized training)
grad_weight = None
if ctx.weight_requires_grad:
input_f = input_float.flatten(0, -2).to(compute_dtype)
grad_weight = torch.mm(grad_2d.t(), input_f)
# grad_bias
grad_bias = None
if ctx.has_bias:
grad_bias = grad_2d.sum(dim=0)
return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast): class MixedPrecisionOps(manual_cast):
_quant_config = quant_config _quant_config = quant_config
@ -801,6 +932,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features), orig_shape=(self.out_features, self.in_features),
) )
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4": elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@ -888,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype #If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and _use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0): len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -939,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items(): for key, param in self._parameters.items():
if param is None: if param is None:
continue continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items(): for key, buf in self._buffers.items():
if buf is not None: if buf is not None:
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
@ -950,12 +1127,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations") logging.info("Using mixed precision operations")
disabled = set() disabled = set()
if not nvfp4_compute: if not nvfp4_compute:
disabled.add("nvfp4") disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute: if not fp8_compute:
disabled.add("float8_e4m3fn") disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2") disabled.add("float8_e5m2")

View File

@ -1,6 +1,7 @@
import torch
import comfy.model_management import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
from comfy.cli_args import args from comfy.cli_args import args
@ -12,18 +13,31 @@ def pin_memory(module):
return return
#FIXME: This is a RAM cache trigger event #FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin): if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module._pin = pin
else:
module.pin_failed = True module.pin_failed = True
return False return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
except RuntimeError:
module.pin_failed = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True return True
def unpin_memory(module): def unpin_memory(module):
if get_pin(module) is None: if get_pin(module) is None:
return 0 return 0
size = module._pin.numel() * module._pin.element_size() size = module._pin.numel() * module._pin.element_size()
comfy.model_management.unpin_memory(module._pin)
comfy.model_management.TOTAL_PINNED_MEMORY -= size
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
comfy.model_management.TOTAL_PINNED_MEMORY = 0
del module._pin del module._pin
del module._pin_hostbuf
return size return size

View File

@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name): def get_layout_class(name):
return None return None
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
try:
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
_CK_MXFP8_AVAILABLE = True
except ImportError:
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
if not _CK_MXFP8_AVAILABLE:
class _CKMxfp8Layout:
pass
import comfy.float import comfy.float
# ============================================================================== # ==============================================================================
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
@classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
if tensor.dim() != 2:
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
orig_dtype = tensor.dtype
orig_shape = tuple(tensor.shape)
padded_shape = cls.get_padded_shape(orig_shape)
needs_padding = padded_shape != orig_shape
if stochastic_rounding > 0:
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
else:
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
params = cls.Params(
scale=block_scale,
orig_dtype=orig_dtype,
orig_shape=orig_shape,
)
return qdata, params
class TensorCoreNVFP4Layout(_CKNvfp4Layout): class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod @classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = { QUANT_ALGOS = {
"float8_e4m3fn": { "float8_e4m3fn": {
@ -157,6 +196,14 @@ QUANT_ALGOS = {
}, },
} }
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ============================================================================== # ==============================================================================
# Re-exports for backward compatibility # Re-exports for backward compatibility

View File

@ -871,13 +871,16 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels return pixels
def vae_output_dtype(self):
return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
output = self.process_output( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -887,16 +890,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32): def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
if samples.ndim == 3: if samples.ndim == 3:
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
else: else:
og_shape = samples.shape og_shape = samples.shape
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1)) samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)): def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device)) return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -905,7 +908,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -914,7 +917,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1: if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
out_channels = self.latent_channels out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
else: else:
@ -923,7 +926,7 @@ class VAE:
tile_x = tile_x // extra_channel_size tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1: if self.latent_dim == 1:
@ -932,7 +935,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1) return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)): def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device) return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
@ -950,9 +953,9 @@ class VAE:
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).to(dtype=self.vae_output_dtype()))
if pixel_samples is None: if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number] = out pixel_samples[x:x+batch_number] = out
except Exception as e: except Exception as e:
model_management.raise_non_oom(e) model_management.raise_non_oom(e)
@ -1025,9 +1028,9 @@ class VAE:
samples = None samples = None
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float() out = self.first_stage_model.encode(pixels_in).to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None: if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out samples[x:x + batch_number] = out
except Exception as e: except Exception as e:

View File

@ -20,6 +20,8 @@
import torch import torch
import math import math
import struct import struct
import ctypes
import os
import comfy.memory_management import comfy.memory_management
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
@ -32,7 +34,7 @@ from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
import json import json
import time import time
import mmap import threading
import warnings import warnings
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
@ -81,14 +83,17 @@ _TYPES = {
} }
def load_safetensors(ckpt): def load_safetensors(ckpt):
f = open(ckpt, "rb") import comfy_aimdo.model_mmap
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0] f = open(ckpt, "rb", buffering=0)
header = json.loads(mapping[8:8+header_size].decode("utf-8")) model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
mv = mv[8 + header_size:] header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):]
sd = {} sd = {}
for name, info in header.items(): for name, info in header.items():
@ -102,7 +107,14 @@ def load_safetensors(ckpt):
with warnings.catch_warnings(): with warnings.catch_warnings():
#We are working with read-only RAM by design #We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable") warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"]) tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
return sd, header.get("__metadata__", {}), return sd, header.get("__metadata__", {}),
@ -885,6 +897,10 @@ def set_attr(obj, attr, value):
return prev return prev
def set_attr_param(obj, attr, value): def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value): def set_attr_buffer(obj, attr, value):

View File

@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
node_id="KlingOmniProEditVideoNode", node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video", display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling", category="api node/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.", description="Edit an existing video with the latest model from Kling.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),

View File

@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
node_id="RecraftVectorizeImageNode", node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image", display_name="Recraft Vectorize Image",
category="api node/image/Recraft", category="api node/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.", description="Generates SVG synchronously from an input image.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),

View File

@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
node_id="EmptyLatentAudio", node_id="EmptyLatentAudio",
display_name="Empty Latent Audio", display_name="Empty Latent Audio",
category="latent/audio", category="latent/audio",
essentials_category="Audio",
inputs=[ inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input( IO.Int.Input(
@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
search_aliases=["export mp3"], search_aliases=["export mp3"],
display_name="Save Audio (MP3)", display_name="Save Audio (MP3)",
category="audio", category="audio",
essentials_category="Audio",
inputs=[ inputs=[
IO.Audio.Input("audio"), IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"), IO.String.Input("filename_prefix", default="audio/ComfyUI"),

View File

@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
display_name="Image Compare", display_name="Image Compare",
description="Compares two images side by side with a slider.", description="Compares two images side by side with a slider.",
category="image", category="image",
essentials_category="Image Tools",
is_experimental=True, is_experimental=True,
is_output_node=True, is_output_node=True,
inputs=[ inputs=[

View File

@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
search_aliases=["trim"], search_aliases=["trim"],
display_name="Image Crop", display_name="Image Crop",
category="image/transform", category="image/transform",
essentials_category="Image Tools",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"), IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
node_id="ImageBlend", node_id="ImageBlend",
display_name="Image Blend", display_name="Image Blend",
category="image/postprocessing", category="image/postprocessing",
essentials_category="Image Tools",
inputs=[ inputs=[
io.Image.Input("image1"), io.Image.Input("image1"),
io.Image.Input("image2"), io.Image.Input("image2"),

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler import comfy_extras.nodes_custom_sampler
import folder_paths import folder_paths
import node_helpers import node_helpers
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16, training_dtype=torch.bfloat16,
real_dataset=None, real_dataset=None,
bucket_latents=None, bucket_latents=None,
use_grad_scaler=False,
): ):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = ( self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi) bucket_latents # list of (Bi, C, Hi, Wi)
) )
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling # Precompute bucket offsets and weights for sampling
if bucket_latents is not None: if bucket_latents is not None:
self._init_bucket_data(bucket_latents) self._init_bucket_data(bucket_latents)
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True), batch_sigmas.requires_grad_(True),
**batch_extra_args, **batch_extra_args,
) )
loss = self.loss_fn(x0_pred, x0) loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd: if bwd:
bwd_loss = loss / self.grad_acc bwd_loss = loss / self.grad_acc
bwd_loss.backward() if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
return loss return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device): def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
) )
total_loss += loss total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies) total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward() if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback: if self.loss_callback:
self.loss_callback(total_loss.item()) self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0: if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups: for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]: for param in param_groups["params"]:
if param.grad is None: if param.grad is None:
continue continue
param.grad.data = param.grad.data.to(param.data.dtype) param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step() if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
ui_pbar.update(1) ui_pbar.update(1)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
), ),
io.Combo.Input( io.Combo.Input(
"training_dtype", "training_dtype",
options=["bf16", "fp32"], options=["bf16", "fp32", "none"],
default="bf16", default="bf16",
tooltip="The dtype to use for training.", tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
), ),
io.Combo.Input( io.Combo.Input(
"lora_dtype", "lora_dtype",
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input( io.Boolean.Input(
"offloading", "offloading",
default=False, default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.", tooltip="Offload model weights to CPU during training to save GPU memory.",
), ),
io.Combo.Input( io.Combo.Input(
"existing_lora", "existing_lora",
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype # Setup model and dtype
mp = model.clone() mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype) use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts # Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count( latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode latents, latents_dtype, bucket_mode
) )
# Validate and expand conditioning # Validate and expand conditioning
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed, seed=seed,
training_dtype=dtype, training_dtype=dtype,
bucket_latents=latents, bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
) )
else: else:
train_sampler = TrainSampler( train_sampler = TrainSampler(
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed, seed=seed,
training_dtype=dtype, training_dtype=dtype,
real_dataset=latents if multi_res else None, real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
) )
# Setup guider # Setup guider
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input( io.Int.Input(
"steps", "steps",
optional=True, optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
), ),
], ],
outputs=[], outputs=[],

View File

@ -206,8 +206,8 @@ import hook_breaker_ac10a0
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if comfy.model_management.torch_version_numeric < (2, 8): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG': if args.verbose == 'DEBUG':

View File

@ -1 +1 @@
comfyui_manager==4.1b4 comfyui_manager==4.1b5

View File

@ -32,7 +32,7 @@ async def cache_control(
) )
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-cache") response.headers.setdefault("Cache-Control", "no-store")
return response return response
# Early return for non-image files - no cache headers needed # Early return for non-image files - no cache headers needed

View File

@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
class ConditioningCombine: class ConditioningCombine:
ESSENTIALS_CATEGORY = "Image Generation"
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}} return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@ -1211,9 +1212,6 @@ class GLIGENTextBoxApply:
return (c, ) return (c, )
class EmptyLatentImage: class EmptyLatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
@ -1232,7 +1230,7 @@ class EmptyLatentImage:
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return ({"samples": latent, "downscale_ratio_spacial": 8}, ) return ({"samples": latent, "downscale_ratio_spacial": 8}, )
@ -1724,6 +1722,8 @@ class LoadImage:
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
@ -1748,8 +1748,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO": if img.format == "MPO":
break # ignore all frames except the first one for MPO format break # ignore all frames except the first one for MPO format
@ -1779,6 +1779,7 @@ class LoadImage:
return True return True
class LoadImageMask: class LoadImageMask:
ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@ -1887,6 +1888,7 @@ class ImageScale:
return (s,) return (s,)
class ImageScaleBy: class ImageScaleBy:
ESSENTIALS_CATEGORY = "Image Tools"
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod @classmethod

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.41.19 comfyui-frontend-package==1.41.20
comfyui-workflow-templates==0.9.21 comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
@ -23,7 +23,7 @@ SQLAlchemy
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.8 comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.10 comfy-aimdo>=0.2.12
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3

View File

@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -310,7 +312,7 @@ class PromptServer():
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html")) response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache' response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers["Pragma"] = "no-cache" response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0" response.headers["Expires"] = "0"
return response return response
@ -419,7 +421,24 @@ class PromptServer():
with open(filepath, "wb") as f: with open(filepath, "wb") as f:
f.write(image.file.read()) f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,
"asset_hash": result.asset.hash,
"size": result.asset.size_bytes,
"mime_type": result.asset.mime_type,
"tags": result.tags,
}
except Exception:
logging.warning("Failed to register uploaded image as asset", exc_info=True)
return web.json_response(resp)
else: else:
return web.Response(status=400) return web.Response(status=400)
@ -479,30 +498,43 @@ class PromptServer():
async def view_image(request): async def view_image(request):
if "filename" in request.rel_url.query: if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"] filename = request.rel_url.query["filename"]
filename, output_dir = folder_paths.annotated_filepath(filename)
if not filename: # The frontend's LoadImage combo widget uses asset_hash values
return web.Response(status=400) # (e.g. "blake3:...") as widget values. When litegraph renders the
# node preview, it constructs /view?filename=<asset_hash>, so this
# endpoint must resolve blake3 hashes to their on-disk file paths.
if filename.startswith("blake3:"):
owner_id = self.user_manager.get_request_user_id(request)
result = resolve_hash_to_path(filename, owner_id=owner_id)
if result is None:
return web.Response(status=404)
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
else:
resolved_content_type = None
filename, output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path if not filename:
if filename[0] == '/' or '..' in filename: return web.Response(status=400)
return web.Response(status=400)
if output_dir is None: # validation for security: prevent accessing arbitrary path
type = request.rel_url.query.get("type", "output") if filename[0] == '/' or '..' in filename:
output_dir = folder_paths.get_directory_by_type(type) return web.Response(status=400)
if output_dir is None: if output_dir is None:
return web.Response(status=400) type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
if "subfolder" in request.rel_url.query: if output_dir is None:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) return web.Response(status=400)
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename) if "subfolder" in request.rel_url.query:
file = os.path.join(output_dir, filename) full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
if os.path.isfile(file): if os.path.isfile(file):
if 'preview' in request.rel_url.query: if 'preview' in request.rel_url.query:
@ -562,8 +594,13 @@ class PromptServer():
return web.Response(body=alpha_buffer.read(), content_type='image/png', return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"filename=\"{filename}\""})
else: else:
# Get content type from mimetype, defaulting to 'application/octet-stream' # Use the content type from asset resolution if available,
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' # otherwise guess from the filename.
content_type = (
resolved_content_type
or mimetypes.guess_type(filename)[0]
or 'application/octet-stream'
)
# For security, force certain mimetypes to download instead of display # For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:

View File

@ -0,0 +1,57 @@
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
This catches problems like unnamed FK constraints that prevent batch-mode
drop_constraint from working on real SQLite files (see MB-2).
Migrations 0001 and 0002 are already shipped, so we only exercise
upgrade/downgrade for 0003+.
"""
import os
import pytest
from alembic import command
from alembic.config import Config
# Oldest shipped revision — we upgrade to here as a baseline and never
# downgrade past it.
_BASELINE = "0002_merge_to_asset_references"
def _make_config(db_path: str) -> Config:
root = os.path.join(os.path.dirname(__file__), "../..")
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
cfg = Config(config_path)
cfg.set_main_option("script_location", scripts_path)
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
return cfg
@pytest.fixture
def migration_db(tmp_path):
"""Yield an alembic Config pre-upgraded to the baseline revision."""
db_path = str(tmp_path / "test_migration.db")
cfg = _make_config(db_path)
command.upgrade(cfg, _BASELINE)
yield cfg
def test_upgrade_to_head(migration_db):
"""Upgrade from baseline to head must succeed on a file-backed DB."""
command.upgrade(migration_db, "head")
def test_downgrade_to_baseline(migration_db):
"""Upgrade to head then downgrade back to baseline."""
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
def test_upgrade_downgrade_cycle(migration_db):
"""Full cycle: upgrade → downgrade → upgrade again."""
command.upgrade(migration_db, "head")
command.downgrade(migration_db, _BASELINE)
command.upgrade(migration_db, "head")

View File

@ -10,6 +10,7 @@ from app.assets.database.queries import (
get_asset_by_hash, get_asset_by_hash,
upsert_asset, upsert_asset,
bulk_insert_assets, bulk_insert_assets,
update_asset_hash_and_mime,
) )
@ -142,3 +143,45 @@ class TestBulkInsertAssets:
session.commit() session.commit()
assert session.query(Asset).count() == 200 assert session.query(Asset).count() == 200
class TestMimeTypeImmutability:
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
@pytest.mark.parametrize(
"initial_mime,second_mime,expected_mime",
[
("image/png", "image/jpeg", "image/png"),
(None, "image/png", "image/png"),
],
ids=["preserves_existing", "fills_null"],
)
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
h = f"blake3:upsert_{initial_mime}_{second_mime}"
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
session.commit()
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
assert created is False
assert asset.mime_type == expected_mime
@pytest.mark.parametrize(
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
[
(None, "image/png", None, "image/png", "blake3:upd0"),
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
],
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
)
def test_update_asset_hash_and_mime_immutability(
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
):
h = expected_hash.removesuffix("_new")
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
session.add(asset)
session.flush()
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
assert asset.mime_type == expected_mime
assert asset.hash == expected_hash

View File

@ -242,22 +242,24 @@ class TestSetReferencePreview:
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit() session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
assert ref.preview_id == preview_asset.id assert ref.preview_id == preview_ref.id
def test_clears_preview(self, session: Session): def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
session.commit() session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
@ -265,15 +267,15 @@ class TestSetReferencePreview:
def test_raises_for_nonexistent_reference(self, session: Session): def test_raises_for_nonexistent_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
def test_raises_for_nonexistent_preview(self, session: Session): def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
session.commit() session.commit()
with pytest.raises(ValueError, match="Preview Asset"): with pytest.raises(ValueError, match="Preview AssetReference"):
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
class TestInsertReference: class TestInsertReference:
@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit() session.commit()
update_reference_timestamps(session, ref, preview_id=preview_asset.id) update_reference_timestamps(session, ref, preview_id=preview_ref.id)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
assert ref.preview_id == preview_asset.id assert ref.preview_id == preview_ref.id
class TestSetReferenceMetadata: class TestSetReferenceMetadata:

View File

@ -20,6 +20,7 @@ def _make_reference(
asset: Asset, asset: Asset,
name: str, name: str,
metadata: dict | None = None, metadata: dict | None = None,
system_metadata: dict | None = None,
) -> AssetReference: ) -> AssetReference:
now = get_utc_now() now = get_utc_now()
ref = AssetReference( ref = AssetReference(
@ -27,6 +28,7 @@ def _make_reference(
name=name, name=name,
asset_id=asset.id, asset_id=asset.id,
user_metadata=metadata, user_metadata=metadata,
system_metadata=system_metadata,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
last_access_time=now, last_access_time=now,
@ -34,8 +36,10 @@ def _make_reference(
session.add(ref) session.add(ref)
session.flush() session.flush()
if metadata: # Build merged projection: {**system_metadata, **user_metadata}
for key, val in metadata.items(): merged = {**(system_metadata or {}), **(metadata or {})}
if merged:
for key, val in merged.items():
for row in convert_metadata_to_rows(key, val): for row in convert_metadata_to_rows(key, val):
meta_row = AssetReferenceMeta( meta_row = AssetReferenceMeta(
asset_reference_id=ref.id, asset_reference_id=ref.id,
@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
refs, _, total = list_references_page(session, metadata_filter={}) refs, _, total = list_references_page(session, metadata_filter={})
assert total == 2 assert total == 2
class TestSystemMetadataProjection:
"""Tests for system_metadata merging into the filter projection."""
def test_system_metadata_keys_are_filterable(self, session: Session):
"""system_metadata keys should appear in the merged projection."""
asset = _make_asset(session, "hash1")
_make_reference(
session, asset, "with_sys",
system_metadata={"source": "scanner"},
)
_make_reference(session, asset, "without_sys")
session.commit()
refs, _, total = list_references_page(
session, metadata_filter={"source": "scanner"}
)
assert total == 1
assert refs[0].name == "with_sys"
def test_user_metadata_overrides_system_metadata(self, session: Session):
"""user_metadata should win when both have the same key."""
asset = _make_asset(session, "hash1")
_make_reference(
session, asset, "overridden",
metadata={"origin": "user_upload"},
system_metadata={"origin": "auto_scan"},
)
session.commit()
# Should match the user value, not the system value
refs, _, total = list_references_page(
session, metadata_filter={"origin": "user_upload"}
)
assert total == 1
assert refs[0].name == "overridden"
# Should NOT match the system value (it was overridden)
refs, _, total = list_references_page(
session, metadata_filter={"origin": "auto_scan"}
)
assert total == 0

View File

@ -11,6 +11,7 @@ from app.assets.services import (
delete_asset_reference, delete_asset_reference,
set_asset_preview, set_asset_preview,
) )
from app.assets.services.asset_management import resolve_hash_to_path
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
@ -219,31 +220,33 @@ class TestSetAssetPreview:
asset = _make_asset(session, hash_val="blake3:main") asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview") preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref_id = ref.id ref_id = ref.id
preview_id = preview_asset.id preview_ref_id = preview_ref.id
session.commit() session.commit()
set_asset_preview( set_asset_preview(
reference_id=ref_id, reference_id=ref_id,
preview_asset_id=preview_id, preview_reference_id=preview_ref_id,
) )
# Verify by re-fetching from DB # Verify by re-fetching from DB
session.expire_all() session.expire_all()
updated_ref = session.get(AssetReference, ref_id) updated_ref = session.get(AssetReference, ref_id)
assert updated_ref.preview_id == preview_id assert updated_ref.preview_id == preview_ref_id
def test_clears_preview(self, mock_create_session, session: Session): def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session) asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview") preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
ref_id = ref.id ref_id = ref.id
session.commit() session.commit()
set_asset_preview( set_asset_preview(
reference_id=ref_id, reference_id=ref_id,
preview_asset_id=None, preview_reference_id=None,
) )
# Verify by re-fetching from DB # Verify by re-fetching from DB
@ -263,6 +266,45 @@ class TestSetAssetPreview:
with pytest.raises(PermissionError, match="not owner"): with pytest.raises(PermissionError, match="not owner"):
set_asset_preview( set_asset_preview(
reference_id=ref.id, reference_id=ref.id,
preview_asset_id=None, preview_reference_id=None,
owner_id="user2", owner_id="user2",
) )
class TestResolveHashToPath:
def test_returns_none_for_unknown_hash(self, mock_create_session):
result = resolve_hash_to_path("blake3:" + "a" * 64)
assert result is None
@pytest.mark.parametrize(
"ref_owner, query_owner, expect_found",
[
("user1", "user1", True),
("user1", "user2", False),
("", "anyone", True),
("", "", True),
],
ids=[
"owner_sees_own_ref",
"other_owner_blocked",
"ownerless_visible_to_anyone",
"ownerless_visible_to_empty",
],
)
def test_owner_visibility(
self, ref_owner, query_owner, expect_found,
mock_create_session, session: Session, temp_dir,
):
f = temp_dir / "file.bin"
f.write_bytes(b"data")
asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
ref.file_path = str(f)
session.commit()
result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
if expect_found:
assert result is not None
assert result.abs_path == str(f)
else:
assert result is None

View File

@ -113,11 +113,19 @@ class TestIngestFileFromPath:
file_path = temp_dir / "with_preview.bin" file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data") file_path.write_bytes(b"data")
# Create a preview asset first # Create a preview asset and reference
preview_asset = Asset(hash="blake3:preview", size_bytes=100) preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset) session.add(preview_asset)
session.flush()
from app.assets.helpers import get_utc_now
now = get_utc_now()
preview_ref = AssetReference(
asset_id=preview_asset.id, name="preview.png", owner_id="",
created_at=now, updated_at=now, last_access_time=now,
)
session.add(preview_ref)
session.commit() session.commit()
preview_id = preview_asset.id preview_id = preview_ref.id
result = _ingest_file_from_path( result = _ingest_file_from_path(
abs_path=str(file_path), abs_path=str(file_path),

View File

@ -0,0 +1,123 @@
"""Tests for list_tag_histogram service function."""
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
from app.assets.helpers import get_utc_now
from app.assets.services.tagging import list_tag_histogram
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
asset = Asset(hash=hash_val, size_bytes=1024)
session.add(asset)
session.flush()
return asset
def _make_reference(
session: Session,
asset: Asset,
name: str = "test",
owner_id: str = "",
) -> AssetReference:
now = get_utc_now()
ref = AssetReference(
owner_id=owner_id,
name=name,
asset_id=asset.id,
created_at=now,
updated_at=now,
last_access_time=now,
)
session.add(ref)
session.flush()
return ref
class TestListTagHistogram:
def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="r1")
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="r2")
add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
session.commit()
result = list_tag_histogram()
assert result["alpha"] == 2
assert result["beta"] == 1
def test_empty_when_no_assets(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["unused"])
session.commit()
result = list_tag_histogram()
assert result == {}
def test_include_tags_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["models", "loras", "input"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="r1")
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="r2")
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
session.commit()
result = list_tag_histogram(include_tags=["models"])
# Only r1 has "models", so only its tags appear
assert "models" in result
assert "loras" in result
assert "input" not in result
def test_exclude_tags_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["models", "loras", "input"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="r1")
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="r2")
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
session.commit()
result = list_tag_histogram(exclude_tags=["models"])
# r1 excluded, only r2's tags remain
assert "input" in result
assert "loras" not in result
def test_name_contains_filter(self, mock_create_session, session: Session):
ensure_tags_exist(session, ["alpha", "beta"])
a1 = _make_asset(session, "blake3:aaa")
r1 = _make_reference(session, a1, name="my_model.safetensors")
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
a2 = _make_asset(session, "blake3:bbb")
r2 = _make_reference(session, a2, name="picture.png")
add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
session.commit()
result = list_tag_histogram(name_contains="model")
assert "alpha" in result
assert "beta" not in result
def test_limit_caps_results(self, mock_create_session, session: Session):
tags = [f"tag{i}" for i in range(10)]
ensure_tags_exist(session, tags)
a = _make_asset(session, "blake3:aaa")
r = _make_reference(session, a, name="r1")
add_tags_to_reference(session, reference_id=r.id, tags=tags)
session.commit()
result = list_tag_histogram(limit=3)
assert len(result) == 3

View File

@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
body = r.json()
assert r.status_code == 400
assert body["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize("root", ["input", "output"]) @pytest.mark.parametrize("root", ["input", "output"])
def test_duplicate_upload_same_display_name_does_not_clobber( def test_duplicate_upload_same_display_name_does_not_clobber(
root: str, root: str,

View File

@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
}, },
# JavaScript/CSS scenarios # JavaScript/CSS scenarios
{ {
"name": "js_no_cache", "name": "js_no_store",
"path": "/script.js", "path": "/script.js",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
{ {
"name": "css_no_cache", "name": "css_no_store",
"path": "/styles.css", "path": "/styles.css",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
{ {
"name": "index_json_no_cache", "name": "index_json_no_store",
"path": "/api/index.json", "path": "/api/index.json",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
{ {
"name": "localized_index_json_no_cache", "name": "localized_index_json_no_store",
"path": "/templates/index.zh.json", "path": "/templates/index.zh.json",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
# Non-matching files # Non-matching files