mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-05-04 22:32:32 +08:00
Merge upstream/master, keep local README.md
This commit is contained in:
commit
dd36f8d9f1
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
103
.github/scripts/check-ai-co-authors.sh
vendored
Executable file
@ -0,0 +1,103 @@
|
|||||||
|
#!/usr/bin/env bash
|
||||||
|
# Checks pull request commits for AI agent Co-authored-by trailers.
|
||||||
|
# Exits non-zero when any are found and prints fix instructions.
|
||||||
|
set -euo pipefail
|
||||||
|
|
||||||
|
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||||
|
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
|
||||||
|
|
||||||
|
# Known AI coding-agent trailer patterns (case-insensitive).
|
||||||
|
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
|
||||||
|
AGENT_PATTERNS=(
|
||||||
|
# Anthropic — Claude Code / Amp
|
||||||
|
'noreply@anthropic\.com'
|
||||||
|
# Cursor
|
||||||
|
'cursoragent@cursor\.com'
|
||||||
|
# GitHub Copilot
|
||||||
|
'copilot-swe-agent\[bot\]'
|
||||||
|
'copilot@github\.com'
|
||||||
|
# OpenAI Codex
|
||||||
|
'noreply@openai\.com'
|
||||||
|
'codex@openai\.com'
|
||||||
|
# Aider
|
||||||
|
'aider@aider\.chat'
|
||||||
|
# Google — Gemini / Jules
|
||||||
|
'gemini@google\.com'
|
||||||
|
'jules@google\.com'
|
||||||
|
# Windsurf / Codeium
|
||||||
|
'@codeium\.com'
|
||||||
|
# Devin
|
||||||
|
'devin-ai-integration\[bot\]'
|
||||||
|
'devin@cognition\.ai'
|
||||||
|
'devin@cognition-labs\.com'
|
||||||
|
# Amazon Q Developer
|
||||||
|
'amazon-q-developer'
|
||||||
|
'@amazon\.com.*[Qq].[Dd]eveloper'
|
||||||
|
# Cline
|
||||||
|
'cline-bot'
|
||||||
|
'cline@cline\.ai'
|
||||||
|
# Continue
|
||||||
|
'continue-agent'
|
||||||
|
'continue@continue\.dev'
|
||||||
|
# Sourcegraph
|
||||||
|
'noreply@sourcegraph\.com'
|
||||||
|
# Generic catch-alls for common agent name patterns
|
||||||
|
'Co-authored-by:.*\b[Cc]laude\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]opilot\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]ursor\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]odex\b'
|
||||||
|
'Co-authored-by:.*\b[Gg]emini\b'
|
||||||
|
'Co-authored-by:.*\b[Aa]ider\b'
|
||||||
|
'Co-authored-by:.*\b[Dd]evin\b'
|
||||||
|
'Co-authored-by:.*\b[Ww]indsurf\b'
|
||||||
|
'Co-authored-by:.*\b[Cc]line\b'
|
||||||
|
'Co-authored-by:.*\b[Aa]mazon Q\b'
|
||||||
|
'Co-authored-by:.*\b[Jj]ules\b'
|
||||||
|
'Co-authored-by:.*\bOpenCode\b'
|
||||||
|
)
|
||||||
|
|
||||||
|
# Build a single alternation regex from all patterns.
|
||||||
|
regex=""
|
||||||
|
for pattern in "${AGENT_PATTERNS[@]}"; do
|
||||||
|
if [[ -n "$regex" ]]; then
|
||||||
|
regex="${regex}|${pattern}"
|
||||||
|
else
|
||||||
|
regex="$pattern"
|
||||||
|
fi
|
||||||
|
done
|
||||||
|
|
||||||
|
# Collect Co-authored-by lines from every commit in the PR range.
|
||||||
|
violations=""
|
||||||
|
while IFS= read -r sha; do
|
||||||
|
message="$(git log -1 --format='%B' "$sha")"
|
||||||
|
matched_lines="$(echo "$message" | grep -iE "^Co-authored-by:" || true)"
|
||||||
|
if [[ -z "$matched_lines" ]]; then
|
||||||
|
continue
|
||||||
|
fi
|
||||||
|
|
||||||
|
while IFS= read -r line; do
|
||||||
|
if echo "$line" | grep -iqE "$regex"; then
|
||||||
|
short="$(git log -1 --format='%h' "$sha")"
|
||||||
|
violations="${violations} ${short}: ${line}"$'\n'
|
||||||
|
fi
|
||||||
|
done <<< "$matched_lines"
|
||||||
|
done < <(git rev-list "${base_sha}..${head_sha}")
|
||||||
|
|
||||||
|
if [[ -n "$violations" ]]; then
|
||||||
|
echo "::error::AI agent Co-authored-by trailers detected in PR commits."
|
||||||
|
echo ""
|
||||||
|
echo "The following commits contain Co-authored-by trailers from AI coding agents:"
|
||||||
|
echo ""
|
||||||
|
echo "$violations"
|
||||||
|
echo "These trailers should be removed before merging."
|
||||||
|
echo ""
|
||||||
|
echo "To fix, rewrite the commit messages with:"
|
||||||
|
echo " git rebase -i ${base_sha}"
|
||||||
|
echo ""
|
||||||
|
echo "and remove the Co-authored-by lines, then force-push your branch."
|
||||||
|
echo ""
|
||||||
|
echo "If you believe this is a false positive, please open an issue."
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
||||||
|
echo "No AI agent Co-authored-by trailers found."
|
||||||
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
19
.github/workflows/check-ai-co-authors.yml
vendored
Normal file
@ -0,0 +1,19 @@
|
|||||||
|
name: Check AI Co-Authors
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches: ['*']
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-ai-co-authors:
|
||||||
|
name: Check for AI agent co-author trailers
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
with:
|
||||||
|
fetch-depth: 0
|
||||||
|
|
||||||
|
- name: Check commits for AI co-author trailers
|
||||||
|
run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"
|
||||||
@ -8,7 +8,7 @@ from alembic import context
|
|||||||
config = context.config
|
config = context.config
|
||||||
|
|
||||||
|
|
||||||
from app.database.models import Base
|
from app.database.models import Base, NAMING_CONVENTION
|
||||||
target_metadata = Base.metadata
|
target_metadata = Base.metadata
|
||||||
|
|
||||||
# other values from the config, defined by the needs of env.py,
|
# other values from the config, defined by the needs of env.py,
|
||||||
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
|
|||||||
|
|
||||||
with connectable.connect() as connection:
|
with connectable.connect() as connection:
|
||||||
context.configure(
|
context.configure(
|
||||||
connection=connection, target_metadata=target_metadata
|
connection=connection,
|
||||||
|
target_metadata=target_metadata,
|
||||||
|
render_as_batch=True,
|
||||||
|
naming_convention=NAMING_CONVENTION,
|
||||||
)
|
)
|
||||||
|
|
||||||
with context.begin_transaction():
|
with context.begin_transaction():
|
||||||
|
|||||||
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
98
alembic_db/versions/0003_add_metadata_job_id.py
Normal file
@ -0,0 +1,98 @@
|
|||||||
|
"""
|
||||||
|
Add system_metadata and job_id columns to asset_references.
|
||||||
|
Change preview_id FK from assets.id to asset_references.id.
|
||||||
|
|
||||||
|
Revision ID: 0003_add_metadata_job_id
|
||||||
|
Revises: 0002_merge_to_asset_references
|
||||||
|
Create Date: 2026-03-09
|
||||||
|
"""
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from app.database.models import NAMING_CONVENTION
|
||||||
|
|
||||||
|
revision = "0003_add_metadata_job_id"
|
||||||
|
down_revision = "0002_merge_to_asset_references"
|
||||||
|
branch_labels = None
|
||||||
|
depends_on = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
|
||||||
|
with op.batch_alter_table("asset_references") as batch_op:
|
||||||
|
batch_op.add_column(
|
||||||
|
sa.Column("system_metadata", sa.JSON(), nullable=True)
|
||||||
|
)
|
||||||
|
batch_op.add_column(
|
||||||
|
sa.Column("job_id", sa.String(length=36), nullable=True)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Change preview_id FK from assets.id to asset_references.id (self-ref).
|
||||||
|
# Existing values are asset-content IDs that won't match reference IDs,
|
||||||
|
# so null them out first.
|
||||||
|
op.execute("UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL")
|
||||||
|
with op.batch_alter_table(
|
||||||
|
"asset_references", naming_convention=NAMING_CONVENTION
|
||||||
|
) as batch_op:
|
||||||
|
batch_op.drop_constraint(
|
||||||
|
"fk_asset_references_preview_id_assets", type_="foreignkey"
|
||||||
|
)
|
||||||
|
batch_op.create_foreign_key(
|
||||||
|
"fk_asset_references_preview_id_asset_references",
|
||||||
|
"asset_references",
|
||||||
|
["preview_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
batch_op.create_index(
|
||||||
|
"ix_asset_references_preview_id", ["preview_id"]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Purge any all-null meta rows before adding the constraint
|
||||||
|
op.execute(
|
||||||
|
"DELETE FROM asset_reference_meta"
|
||||||
|
" WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
|
||||||
|
)
|
||||||
|
with op.batch_alter_table("asset_reference_meta") as batch_op:
|
||||||
|
batch_op.create_check_constraint(
|
||||||
|
"ck_asset_reference_meta_has_value",
|
||||||
|
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
|
||||||
|
# SQLite doesn't reflect CHECK constraints, so we must declare it
|
||||||
|
# explicitly via table_args for the batch recreate to find it.
|
||||||
|
# Use the fully-rendered constraint name to avoid the naming convention
|
||||||
|
# doubling the prefix.
|
||||||
|
with op.batch_alter_table(
|
||||||
|
"asset_reference_meta",
|
||||||
|
table_args=[
|
||||||
|
sa.CheckConstraint(
|
||||||
|
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||||
|
name="ck_asset_reference_meta_has_value",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
) as batch_op:
|
||||||
|
batch_op.drop_constraint(
|
||||||
|
"ck_asset_reference_meta_has_value", type_="check"
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table(
|
||||||
|
"asset_references", naming_convention=NAMING_CONVENTION
|
||||||
|
) as batch_op:
|
||||||
|
batch_op.drop_index("ix_asset_references_preview_id")
|
||||||
|
batch_op.drop_constraint(
|
||||||
|
"fk_asset_references_preview_id_asset_references", type_="foreignkey"
|
||||||
|
)
|
||||||
|
batch_op.create_foreign_key(
|
||||||
|
"fk_asset_references_preview_id_assets",
|
||||||
|
"assets",
|
||||||
|
["preview_id"],
|
||||||
|
["id"],
|
||||||
|
ondelete="SET NULL",
|
||||||
|
)
|
||||||
|
|
||||||
|
with op.batch_alter_table("asset_references") as batch_op:
|
||||||
|
batch_op.drop_column("job_id")
|
||||||
|
batch_op.drop_column("system_metadata")
|
||||||
@ -13,6 +13,7 @@ from pydantic import ValidationError
|
|||||||
import folder_paths
|
import folder_paths
|
||||||
from app import user_manager
|
from app import user_manager
|
||||||
from app.assets.api import schemas_in, schemas_out
|
from app.assets.api import schemas_in, schemas_out
|
||||||
|
from app.assets.services import schemas
|
||||||
from app.assets.api.schemas_in import (
|
from app.assets.api.schemas_in import (
|
||||||
AssetValidationError,
|
AssetValidationError,
|
||||||
UploadError,
|
UploadError,
|
||||||
@ -38,6 +39,7 @@ from app.assets.services import (
|
|||||||
update_asset_metadata,
|
update_asset_metadata,
|
||||||
upload_from_temp_path,
|
upload_from_temp_path,
|
||||||
)
|
)
|
||||||
|
from app.assets.services.tagging import list_tag_histogram
|
||||||
|
|
||||||
ROUTES = web.RouteTableDef()
|
ROUTES = web.RouteTableDef()
|
||||||
USER_MANAGER: user_manager.UserManager | None = None
|
USER_MANAGER: user_manager.UserManager | None = None
|
||||||
@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str:
|
|||||||
return "created_at"
|
return "created_at"
|
||||||
|
|
||||||
|
|
||||||
|
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
|
||||||
|
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
|
||||||
|
if not user_metadata:
|
||||||
|
return None
|
||||||
|
filename = user_metadata.get("filename")
|
||||||
|
if not filename:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if "input" in tags:
|
||||||
|
view_type = "input"
|
||||||
|
elif "output" in tags:
|
||||||
|
view_type = "output"
|
||||||
|
else:
|
||||||
|
return None
|
||||||
|
|
||||||
|
subfolder = ""
|
||||||
|
if "/" in filename:
|
||||||
|
subfolder, filename = filename.rsplit("/", 1)
|
||||||
|
|
||||||
|
encoded_filename = urllib.parse.quote(filename, safe="")
|
||||||
|
url = f"/api/view?type={view_type}&filename={encoded_filename}"
|
||||||
|
if subfolder:
|
||||||
|
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
|
||||||
|
return url
|
||||||
|
|
||||||
|
|
||||||
|
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
|
||||||
|
"""Build an Asset response from a service result."""
|
||||||
|
if result.ref.preview_id:
|
||||||
|
preview_detail = get_asset_detail(result.ref.preview_id)
|
||||||
|
if preview_detail:
|
||||||
|
preview_url = _build_preview_url_from_view(preview_detail.tags, preview_detail.ref.user_metadata)
|
||||||
|
else:
|
||||||
|
preview_url = None
|
||||||
|
else:
|
||||||
|
preview_url = _build_preview_url_from_view(result.tags, result.ref.user_metadata)
|
||||||
|
return schemas_out.Asset(
|
||||||
|
id=result.ref.id,
|
||||||
|
name=result.ref.name,
|
||||||
|
asset_hash=result.asset.hash if result.asset else None,
|
||||||
|
size=int(result.asset.size_bytes) if result.asset else None,
|
||||||
|
mime_type=result.asset.mime_type if result.asset else None,
|
||||||
|
tags=result.tags,
|
||||||
|
preview_url=preview_url,
|
||||||
|
preview_id=result.ref.preview_id,
|
||||||
|
user_metadata=result.ref.user_metadata or {},
|
||||||
|
metadata=result.ref.system_metadata,
|
||||||
|
job_id=result.ref.job_id,
|
||||||
|
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||||
|
created_at=result.ref.created_at,
|
||||||
|
updated_at=result.ref.updated_at,
|
||||||
|
last_access_time=result.ref.last_access_time,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.head("/api/assets/hash/{hash}")
|
@ROUTES.head("/api/assets/hash/{hash}")
|
||||||
@_require_assets_feature_enabled
|
@_require_assets_feature_enabled
|
||||||
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
async def head_asset_by_hash(request: web.Request) -> web.Response:
|
||||||
@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
|||||||
order=order,
|
order=order,
|
||||||
)
|
)
|
||||||
|
|
||||||
summaries = [
|
summaries = [_build_asset_response(item) for item in result.items]
|
||||||
schemas_out.AssetSummary(
|
|
||||||
id=item.ref.id,
|
|
||||||
name=item.ref.name,
|
|
||||||
asset_hash=item.asset.hash if item.asset else None,
|
|
||||||
size=int(item.asset.size_bytes) if item.asset else None,
|
|
||||||
mime_type=item.asset.mime_type if item.asset else None,
|
|
||||||
tags=item.tags,
|
|
||||||
created_at=item.ref.created_at,
|
|
||||||
updated_at=item.ref.updated_at,
|
|
||||||
last_access_time=item.ref.last_access_time,
|
|
||||||
)
|
|
||||||
for item in result.items
|
|
||||||
]
|
|
||||||
|
|
||||||
payload = schemas_out.AssetsList(
|
payload = schemas_out.AssetsList(
|
||||||
assets=summaries,
|
assets=summaries,
|
||||||
@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
|||||||
{"id": reference_id},
|
{"id": reference_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
payload = schemas_out.AssetDetail(
|
payload = _build_asset_response(result)
|
||||||
id=result.ref.id,
|
|
||||||
name=result.ref.name,
|
|
||||||
asset_hash=result.asset.hash if result.asset else None,
|
|
||||||
size=int(result.asset.size_bytes) if result.asset else None,
|
|
||||||
mime_type=result.asset.mime_type if result.asset else None,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
preview_id=result.ref.preview_id,
|
|
||||||
created_at=result.ref.created_at,
|
|
||||||
last_access_time=result.ref.last_access_time,
|
|
||||||
)
|
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
return _build_error_response(
|
return _build_error_response(
|
||||||
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
|
||||||
@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
|
|||||||
USER_MANAGER.get_request_user_id(request),
|
USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
|
||||||
@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
|
|||||||
400, "INVALID_JSON", "Request body must be valid JSON."
|
400, "INVALID_JSON", "Request body must be valid JSON."
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Derive name from hash if not provided
|
||||||
|
name = body.name
|
||||||
|
if name is None:
|
||||||
|
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
|
||||||
|
|
||||||
result = create_from_hash(
|
result = create_from_hash(
|
||||||
hash_str=body.hash,
|
hash_str=body.hash,
|
||||||
name=body.name,
|
name=name,
|
||||||
tags=body.tags,
|
tags=body.tags,
|
||||||
user_metadata=body.user_metadata,
|
user_metadata=body.user_metadata,
|
||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
mime_type=body.mime_type,
|
||||||
|
preview_id=body.preview_id,
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
return _build_error_response(
|
return _build_error_response(
|
||||||
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
asset = _build_asset_response(result)
|
||||||
payload_out = schemas_out.AssetCreated(
|
payload_out = schemas_out.AssetCreated(
|
||||||
id=result.ref.id,
|
**asset.model_dump(),
|
||||||
name=result.ref.name,
|
|
||||||
asset_hash=result.asset.hash,
|
|
||||||
size=int(result.asset.size_bytes),
|
|
||||||
mime_type=result.asset.mime_type,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
preview_id=result.ref.preview_id,
|
|
||||||
created_at=result.ref.created_at,
|
|
||||||
last_access_time=result.ref.last_access_time,
|
|
||||||
created_new=result.created_new,
|
created_new=result.created_new,
|
||||||
)
|
)
|
||||||
return web.json_response(payload_out.model_dump(mode="json"), status=201)
|
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post("/api/assets")
|
@ROUTES.post("/api/assets")
|
||||||
@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
"name": parsed.provided_name,
|
"name": parsed.provided_name,
|
||||||
"user_metadata": parsed.user_metadata_raw,
|
"user_metadata": parsed.user_metadata_raw,
|
||||||
"hash": parsed.provided_hash,
|
"hash": parsed.provided_hash,
|
||||||
|
"mime_type": parsed.provided_mime_type,
|
||||||
|
"preview_id": parsed.provided_preview_id,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
except ValidationError as ve:
|
except ValidationError as ve:
|
||||||
@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
tags=spec.tags,
|
tags=spec.tags,
|
||||||
user_metadata=spec.user_metadata or {},
|
user_metadata=spec.user_metadata or {},
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
|
mime_type=spec.mime_type,
|
||||||
|
preview_id=spec.preview_id,
|
||||||
)
|
)
|
||||||
if result is None:
|
if result is None:
|
||||||
delete_temp_file_if_exists(parsed.tmp_path)
|
delete_temp_file_if_exists(parsed.tmp_path)
|
||||||
@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
client_filename=parsed.file_client_name,
|
client_filename=parsed.file_client_name,
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
expected_hash=spec.hash,
|
expected_hash=spec.hash,
|
||||||
|
mime_type=spec.mime_type,
|
||||||
|
preview_id=spec.preview_id,
|
||||||
)
|
)
|
||||||
except AssetValidationError as e:
|
except AssetValidationError as e:
|
||||||
delete_temp_file_if_exists(parsed.tmp_path)
|
delete_temp_file_if_exists(parsed.tmp_path)
|
||||||
@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response:
|
|||||||
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
logging.exception("upload_asset failed for owner_id=%s", owner_id)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
payload = schemas_out.AssetCreated(
|
asset = _build_asset_response(result)
|
||||||
id=result.ref.id,
|
payload_out = schemas_out.AssetCreated(
|
||||||
name=result.ref.name,
|
**asset.model_dump(),
|
||||||
asset_hash=result.asset.hash,
|
|
||||||
size=int(result.asset.size_bytes),
|
|
||||||
mime_type=result.asset.mime_type,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
preview_id=result.ref.preview_id,
|
|
||||||
created_at=result.ref.created_at,
|
|
||||||
last_access_time=result.ref.last_access_time,
|
|
||||||
created_new=result.created_new,
|
created_new=result.created_new,
|
||||||
)
|
)
|
||||||
status = 201 if result.created_new else 200
|
status = 201 if result.created_new else 200
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=status)
|
return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
|||||||
name=body.name,
|
name=body.name,
|
||||||
user_metadata=body.user_metadata,
|
user_metadata=body.user_metadata,
|
||||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
preview_id=body.preview_id,
|
||||||
)
|
)
|
||||||
payload = schemas_out.AssetUpdated(
|
payload = _build_asset_response(result)
|
||||||
id=result.ref.id,
|
|
||||||
name=result.ref.name,
|
|
||||||
asset_hash=result.asset.hash if result.asset else None,
|
|
||||||
tags=result.tags,
|
|
||||||
user_metadata=result.ref.user_metadata or {},
|
|
||||||
updated_at=result.ref.updated_at,
|
|
||||||
)
|
|
||||||
except PermissionError as pe:
|
except PermissionError as pe:
|
||||||
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
|
||||||
except ValueError as ve:
|
except ValueError as ve:
|
||||||
@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
|||||||
USER_MANAGER.get_request_user_id(request),
|
USER_MANAGER.get_request_user_id(request),
|
||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
|
||||||
@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
|
|||||||
payload = schemas_out.TagsList(
|
payload = schemas_out.TagsList(
|
||||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||||
)
|
)
|
||||||
return web.json_response(payload.model_dump(mode="json"))
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
|
|||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
|
||||||
@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
|
|||||||
)
|
)
|
||||||
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
return _build_error_response(500, "INTERNAL", "Unexpected server error.")
|
||||||
|
|
||||||
return web.json_response(payload.model_dump(mode="json"), status=200)
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
|
@ROUTES.get("/api/assets/tags/refine")
|
||||||
|
@_require_assets_feature_enabled
|
||||||
|
async def get_tags_refine(request: web.Request) -> web.Response:
|
||||||
|
"""GET request to get tag histogram for filtered assets."""
|
||||||
|
query_dict = get_query_dict(request)
|
||||||
|
try:
|
||||||
|
q = schemas_in.TagsRefineQuery.model_validate(query_dict)
|
||||||
|
except ValidationError as ve:
|
||||||
|
return _build_validation_error_response("INVALID_QUERY", ve)
|
||||||
|
|
||||||
|
tag_counts = list_tag_histogram(
|
||||||
|
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||||
|
include_tags=q.include_tags,
|
||||||
|
exclude_tags=q.exclude_tags,
|
||||||
|
name_contains=q.name_contains,
|
||||||
|
metadata_filter=q.metadata_filter,
|
||||||
|
limit=q.limit,
|
||||||
|
)
|
||||||
|
payload = schemas_out.TagHistogram(tag_counts=tag_counts)
|
||||||
|
return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
|
||||||
|
|
||||||
|
|
||||||
@ROUTES.post("/api/assets/seed")
|
@ROUTES.post("/api/assets/seed")
|
||||||
|
|||||||
@ -45,6 +45,8 @@ class ParsedUpload:
|
|||||||
user_metadata_raw: str | None
|
user_metadata_raw: str | None
|
||||||
provided_hash: str | None
|
provided_hash: str | None
|
||||||
provided_hash_exists: bool | None
|
provided_hash_exists: bool | None
|
||||||
|
provided_mime_type: str | None = None
|
||||||
|
provided_preview_id: str | None = None
|
||||||
|
|
||||||
|
|
||||||
class ListAssetsQuery(BaseModel):
|
class ListAssetsQuery(BaseModel):
|
||||||
@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel):
|
|||||||
class UpdateAssetBody(BaseModel):
|
class UpdateAssetBody(BaseModel):
|
||||||
name: str | None = None
|
name: str | None = None
|
||||||
user_metadata: dict[str, Any] | None = None
|
user_metadata: dict[str, Any] | None = None
|
||||||
|
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_at_least_one_field(self):
|
def _validate_at_least_one_field(self):
|
||||||
if self.name is None and self.user_metadata is None:
|
if all(
|
||||||
raise ValueError("Provide at least one of: name, user_metadata.")
|
v is None
|
||||||
|
for v in (self.name, self.user_metadata, self.preview_id)
|
||||||
|
):
|
||||||
|
raise ValueError(
|
||||||
|
"Provide at least one of: name, user_metadata, preview_id."
|
||||||
|
)
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel):
|
|||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
hash: str
|
hash: str
|
||||||
name: str
|
name: str | None = None
|
||||||
tags: list[str] = Field(default_factory=list)
|
tags: list[str] = Field(default_factory=list)
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
mime_type: str | None = None
|
||||||
|
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||||
|
|
||||||
@field_validator("hash")
|
@field_validator("hash")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel):
|
|||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class TagsRefineQuery(BaseModel):
|
||||||
|
include_tags: list[str] = Field(default_factory=list)
|
||||||
|
exclude_tags: list[str] = Field(default_factory=list)
|
||||||
|
name_contains: str | None = None
|
||||||
|
metadata_filter: dict[str, Any] | None = None
|
||||||
|
limit: conint(ge=1, le=1000) = 100
|
||||||
|
|
||||||
|
@field_validator("include_tags", "exclude_tags", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _split_csv_tags(cls, v):
|
||||||
|
if v is None:
|
||||||
|
return []
|
||||||
|
if isinstance(v, str):
|
||||||
|
return [t.strip() for t in v.split(",") if t.strip()]
|
||||||
|
if isinstance(v, list):
|
||||||
|
out: list[str] = []
|
||||||
|
for item in v:
|
||||||
|
if isinstance(item, str):
|
||||||
|
out.extend([t.strip() for t in item.split(",") if t.strip()])
|
||||||
|
return out
|
||||||
|
return v
|
||||||
|
|
||||||
|
@field_validator("metadata_filter", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _parse_metadata_json(cls, v):
|
||||||
|
if v is None or isinstance(v, dict):
|
||||||
|
return v
|
||||||
|
if isinstance(v, str) and v.strip():
|
||||||
|
try:
|
||||||
|
parsed = json.loads(v)
|
||||||
|
except Exception as e:
|
||||||
|
raise ValueError(f"metadata_filter must be JSON: {e}") from e
|
||||||
|
if not isinstance(parsed, dict):
|
||||||
|
raise ValueError("metadata_filter must be a JSON object")
|
||||||
|
return parsed
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
class TagsListQuery(BaseModel):
|
class TagsListQuery(BaseModel):
|
||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
@ -186,21 +234,25 @@ class TagsRemove(TagsAdd):
|
|||||||
class UploadAssetSpec(BaseModel):
|
class UploadAssetSpec(BaseModel):
|
||||||
"""Upload Asset operation.
|
"""Upload Asset operation.
|
||||||
|
|
||||||
- tags: ordered; first is root ('models'|'input'|'output');
|
- tags: optional list; if provided, first is root ('models'|'input'|'output');
|
||||||
if root == 'models', second must be a valid category
|
if root == 'models', second must be a valid category
|
||||||
- name: display name
|
- name: display name
|
||||||
- user_metadata: arbitrary JSON object (optional)
|
- user_metadata: arbitrary JSON object (optional)
|
||||||
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
- hash: optional canonical 'blake3:<hex>' for validation / fast-path
|
||||||
|
- mime_type: optional MIME type override
|
||||||
|
- preview_id: optional asset_reference ID for preview
|
||||||
|
|
||||||
Files are stored using the content hash as filename stem.
|
Files are stored using the content hash as filename stem.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
|
||||||
|
|
||||||
tags: list[str] = Field(..., min_length=1)
|
tags: list[str] = Field(default_factory=list)
|
||||||
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
name: str | None = Field(default=None, max_length=512, description="Display Name")
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
hash: str | None = Field(default=None)
|
hash: str | None = Field(default=None)
|
||||||
|
mime_type: str | None = Field(default=None)
|
||||||
|
preview_id: str | None = Field(default=None) # references an asset_reference id
|
||||||
|
|
||||||
@field_validator("hash", mode="before")
|
@field_validator("hash", mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_order(self):
|
def _validate_order(self):
|
||||||
if not self.tags:
|
if not self.tags:
|
||||||
raise ValueError("tags must be provided and non-empty")
|
raise ValueError("at least one tag is required for uploads")
|
||||||
root = self.tags[0]
|
root = self.tags[0]
|
||||||
if root not in {"models", "input", "output"}:
|
if root not in {"models", "input", "output"}:
|
||||||
raise ValueError("first tag must be one of: models, input, output")
|
raise ValueError("first tag must be one of: models, input, output")
|
||||||
|
|||||||
@ -4,7 +4,10 @@ from typing import Any
|
|||||||
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
from pydantic import BaseModel, ConfigDict, Field, field_serializer
|
||||||
|
|
||||||
|
|
||||||
class AssetSummary(BaseModel):
|
class Asset(BaseModel):
|
||||||
|
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
|
||||||
|
``id`` here is the AssetReference id, not the content-addressed Asset id."""
|
||||||
|
|
||||||
id: str
|
id: str
|
||||||
name: str
|
name: str
|
||||||
asset_hash: str | None = None
|
asset_hash: str | None = None
|
||||||
@ -12,8 +15,14 @@ class AssetSummary(BaseModel):
|
|||||||
mime_type: str | None = None
|
mime_type: str | None = None
|
||||||
tags: list[str] = Field(default_factory=list)
|
tags: list[str] = Field(default_factory=list)
|
||||||
preview_url: str | None = None
|
preview_url: str | None = None
|
||||||
created_at: datetime | None = None
|
preview_id: str | None = None # references an asset_reference id, not an asset id
|
||||||
updated_at: datetime | None = None
|
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||||
|
is_immutable: bool = False
|
||||||
|
metadata: dict[str, Any] | None = None
|
||||||
|
job_id: str | None = None
|
||||||
|
prompt_id: str | None = None # deprecated: use job_id
|
||||||
|
created_at: datetime
|
||||||
|
updated_at: datetime
|
||||||
last_access_time: datetime | None = None
|
last_access_time: datetime | None = None
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
model_config = ConfigDict(from_attributes=True)
|
||||||
@ -23,50 +32,16 @@ class AssetSummary(BaseModel):
|
|||||||
return v.isoformat() if v else None
|
return v.isoformat() if v else None
|
||||||
|
|
||||||
|
|
||||||
|
class AssetCreated(Asset):
|
||||||
|
created_new: bool
|
||||||
|
|
||||||
|
|
||||||
class AssetsList(BaseModel):
|
class AssetsList(BaseModel):
|
||||||
assets: list[AssetSummary]
|
assets: list[Asset]
|
||||||
total: int
|
total: int
|
||||||
has_more: bool
|
has_more: bool
|
||||||
|
|
||||||
|
|
||||||
class AssetUpdated(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
asset_hash: str | None = None
|
|
||||||
tags: list[str] = Field(default_factory=list)
|
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
updated_at: datetime | None = None
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
@field_serializer("updated_at")
|
|
||||||
def _serialize_updated_at(self, v: datetime | None, _info):
|
|
||||||
return v.isoformat() if v else None
|
|
||||||
|
|
||||||
|
|
||||||
class AssetDetail(BaseModel):
|
|
||||||
id: str
|
|
||||||
name: str
|
|
||||||
asset_hash: str | None = None
|
|
||||||
size: int | None = None
|
|
||||||
mime_type: str | None = None
|
|
||||||
tags: list[str] = Field(default_factory=list)
|
|
||||||
user_metadata: dict[str, Any] = Field(default_factory=dict)
|
|
||||||
preview_id: str | None = None
|
|
||||||
created_at: datetime | None = None
|
|
||||||
last_access_time: datetime | None = None
|
|
||||||
|
|
||||||
model_config = ConfigDict(from_attributes=True)
|
|
||||||
|
|
||||||
@field_serializer("created_at", "last_access_time")
|
|
||||||
def _serialize_datetime(self, v: datetime | None, _info):
|
|
||||||
return v.isoformat() if v else None
|
|
||||||
|
|
||||||
|
|
||||||
class AssetCreated(AssetDetail):
|
|
||||||
created_new: bool
|
|
||||||
|
|
||||||
|
|
||||||
class TagUsage(BaseModel):
|
class TagUsage(BaseModel):
|
||||||
name: str
|
name: str
|
||||||
count: int
|
count: int
|
||||||
@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
|
|||||||
removed: list[str] = Field(default_factory=list)
|
removed: list[str] = Field(default_factory=list)
|
||||||
not_present: list[str] = Field(default_factory=list)
|
not_present: list[str] = Field(default_factory=list)
|
||||||
total_tags: list[str] = Field(default_factory=list)
|
total_tags: list[str] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
|
class TagHistogram(BaseModel):
|
||||||
|
tag_counts: dict[str, int]
|
||||||
|
|||||||
@ -52,6 +52,8 @@ async def parse_multipart_upload(
|
|||||||
user_metadata_raw: str | None = None
|
user_metadata_raw: str | None = None
|
||||||
provided_hash: str | None = None
|
provided_hash: str | None = None
|
||||||
provided_hash_exists: bool | None = None
|
provided_hash_exists: bool | None = None
|
||||||
|
provided_mime_type: str | None = None
|
||||||
|
provided_preview_id: str | None = None
|
||||||
|
|
||||||
file_written = 0
|
file_written = 0
|
||||||
tmp_path: str | None = None
|
tmp_path: str | None = None
|
||||||
@ -128,6 +130,16 @@ async def parse_multipart_upload(
|
|||||||
provided_name = (await field.text()) or None
|
provided_name = (await field.text()) or None
|
||||||
elif fname == "user_metadata":
|
elif fname == "user_metadata":
|
||||||
user_metadata_raw = (await field.text()) or None
|
user_metadata_raw = (await field.text()) or None
|
||||||
|
elif fname == "id":
|
||||||
|
raise UploadError(
|
||||||
|
400,
|
||||||
|
"UNSUPPORTED_FIELD",
|
||||||
|
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
|
||||||
|
)
|
||||||
|
elif fname == "mime_type":
|
||||||
|
provided_mime_type = ((await field.text()) or "").strip() or None
|
||||||
|
elif fname == "preview_id":
|
||||||
|
provided_preview_id = ((await field.text()) or "").strip() or None
|
||||||
|
|
||||||
if not file_present and not (provided_hash and provided_hash_exists):
|
if not file_present and not (provided_hash and provided_hash_exists):
|
||||||
raise UploadError(
|
raise UploadError(
|
||||||
@ -152,6 +164,8 @@ async def parse_multipart_upload(
|
|||||||
user_metadata_raw=user_metadata_raw,
|
user_metadata_raw=user_metadata_raw,
|
||||||
provided_hash=provided_hash,
|
provided_hash=provided_hash,
|
||||||
provided_hash_exists=provided_hash_exists,
|
provided_hash_exists=provided_hash_exists,
|
||||||
|
provided_mime_type=provided_mime_type,
|
||||||
|
provided_preview_id=provided_preview_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -45,13 +45,7 @@ class Asset(Base):
|
|||||||
passive_deletes=True,
|
passive_deletes=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
preview_of: Mapped[list[AssetReference]] = relationship(
|
# preview_id on AssetReference is a self-referential FK to asset_references.id
|
||||||
"AssetReference",
|
|
||||||
back_populates="preview_asset",
|
|
||||||
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
|
|
||||||
foreign_keys=lambda: [AssetReference.preview_id],
|
|
||||||
viewonly=True,
|
|
||||||
)
|
|
||||||
|
|
||||||
__table_args__ = (
|
__table_args__ = (
|
||||||
Index("uq_assets_hash", "hash", unique=True),
|
Index("uq_assets_hash", "hash", unique=True),
|
||||||
@ -91,11 +85,15 @@ class AssetReference(Base):
|
|||||||
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
|
||||||
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
name: Mapped[str] = mapped_column(String(512), nullable=False)
|
||||||
preview_id: Mapped[str | None] = mapped_column(
|
preview_id: Mapped[str | None] = mapped_column(
|
||||||
String(36), ForeignKey("assets.id", ondelete="SET NULL")
|
String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
|
||||||
)
|
)
|
||||||
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||||
JSON(none_as_null=True)
|
JSON(none_as_null=True)
|
||||||
)
|
)
|
||||||
|
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
|
||||||
|
JSON(none_as_null=True), nullable=True, default=None
|
||||||
|
)
|
||||||
|
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
|
||||||
created_at: Mapped[datetime] = mapped_column(
|
created_at: Mapped[datetime] = mapped_column(
|
||||||
DateTime(timezone=False), nullable=False, default=get_utc_now
|
DateTime(timezone=False), nullable=False, default=get_utc_now
|
||||||
)
|
)
|
||||||
@ -115,10 +113,10 @@ class AssetReference(Base):
|
|||||||
foreign_keys=[asset_id],
|
foreign_keys=[asset_id],
|
||||||
lazy="selectin",
|
lazy="selectin",
|
||||||
)
|
)
|
||||||
preview_asset: Mapped[Asset | None] = relationship(
|
preview_ref: Mapped[AssetReference | None] = relationship(
|
||||||
"Asset",
|
"AssetReference",
|
||||||
back_populates="preview_of",
|
|
||||||
foreign_keys=[preview_id],
|
foreign_keys=[preview_id],
|
||||||
|
remote_side=lambda: [AssetReference.id],
|
||||||
)
|
)
|
||||||
|
|
||||||
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
|
||||||
@ -152,6 +150,7 @@ class AssetReference(Base):
|
|||||||
Index("ix_asset_references_created_at", "created_at"),
|
Index("ix_asset_references_created_at", "created_at"),
|
||||||
Index("ix_asset_references_last_access_time", "last_access_time"),
|
Index("ix_asset_references_last_access_time", "last_access_time"),
|
||||||
Index("ix_asset_references_deleted_at", "deleted_at"),
|
Index("ix_asset_references_deleted_at", "deleted_at"),
|
||||||
|
Index("ix_asset_references_preview_id", "preview_id"),
|
||||||
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
Index("ix_asset_references_owner_name", "owner_id", "name"),
|
||||||
CheckConstraint(
|
CheckConstraint(
|
||||||
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
|
||||||
@ -192,6 +191,10 @@ class AssetReferenceMeta(Base):
|
|||||||
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
|
||||||
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
|
||||||
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
|
||||||
|
CheckConstraint(
|
||||||
|
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
|
||||||
|
name="has_value",
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import (
|
|||||||
get_unenriched_references,
|
get_unenriched_references,
|
||||||
get_unreferenced_unhashed_asset_ids,
|
get_unreferenced_unhashed_asset_ids,
|
||||||
insert_reference,
|
insert_reference,
|
||||||
|
list_all_file_paths_by_asset_id,
|
||||||
list_references_by_asset_id,
|
list_references_by_asset_id,
|
||||||
list_references_page,
|
list_references_page,
|
||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
|
rebuild_metadata_projection,
|
||||||
|
reference_exists,
|
||||||
reference_exists_for_asset_id,
|
reference_exists_for_asset_id,
|
||||||
restore_references_by_paths,
|
restore_references_by_paths,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
set_reference_preview,
|
set_reference_preview,
|
||||||
|
set_reference_system_metadata,
|
||||||
soft_delete_reference_by_id,
|
soft_delete_reference_by_id,
|
||||||
update_reference_access_time,
|
update_reference_access_time,
|
||||||
update_reference_name,
|
update_reference_name,
|
||||||
|
update_is_missing_by_asset_id,
|
||||||
update_reference_timestamps,
|
update_reference_timestamps,
|
||||||
update_reference_updated_at,
|
update_reference_updated_at,
|
||||||
upsert_reference,
|
upsert_reference,
|
||||||
@ -54,6 +59,7 @@ from app.assets.database.queries.tags import (
|
|||||||
bulk_insert_tags_and_meta,
|
bulk_insert_tags_and_meta,
|
||||||
ensure_tags_exist,
|
ensure_tags_exist,
|
||||||
get_reference_tags,
|
get_reference_tags,
|
||||||
|
list_tag_counts_for_filtered_assets,
|
||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
remove_tags_from_reference,
|
remove_tags_from_reference,
|
||||||
@ -97,20 +103,26 @@ __all__ = [
|
|||||||
"get_unenriched_references",
|
"get_unenriched_references",
|
||||||
"get_unreferenced_unhashed_asset_ids",
|
"get_unreferenced_unhashed_asset_ids",
|
||||||
"insert_reference",
|
"insert_reference",
|
||||||
|
"list_all_file_paths_by_asset_id",
|
||||||
"list_references_by_asset_id",
|
"list_references_by_asset_id",
|
||||||
"list_references_page",
|
"list_references_page",
|
||||||
|
"list_tag_counts_for_filtered_assets",
|
||||||
"list_tags_with_usage",
|
"list_tags_with_usage",
|
||||||
"mark_references_missing_outside_prefixes",
|
"mark_references_missing_outside_prefixes",
|
||||||
"reassign_asset_references",
|
"reassign_asset_references",
|
||||||
|
"rebuild_metadata_projection",
|
||||||
|
"reference_exists",
|
||||||
"reference_exists_for_asset_id",
|
"reference_exists_for_asset_id",
|
||||||
"remove_missing_tag_for_asset_id",
|
"remove_missing_tag_for_asset_id",
|
||||||
"remove_tags_from_reference",
|
"remove_tags_from_reference",
|
||||||
"restore_references_by_paths",
|
"restore_references_by_paths",
|
||||||
"set_reference_metadata",
|
"set_reference_metadata",
|
||||||
"set_reference_preview",
|
"set_reference_preview",
|
||||||
|
"set_reference_system_metadata",
|
||||||
"soft_delete_reference_by_id",
|
"soft_delete_reference_by_id",
|
||||||
"set_reference_tags",
|
"set_reference_tags",
|
||||||
"update_asset_hash_and_mime",
|
"update_asset_hash_and_mime",
|
||||||
|
"update_is_missing_by_asset_id",
|
||||||
"update_reference_access_time",
|
"update_reference_access_time",
|
||||||
"update_reference_name",
|
"update_reference_name",
|
||||||
"update_reference_timestamps",
|
"update_reference_timestamps",
|
||||||
|
|||||||
@ -69,7 +69,7 @@ def upsert_asset(
|
|||||||
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
|
||||||
asset.size_bytes = int(size_bytes)
|
asset.size_bytes = int(size_bytes)
|
||||||
changed = True
|
changed = True
|
||||||
if mime_type and asset.mime_type != mime_type:
|
if mime_type and not asset.mime_type:
|
||||||
asset.mime_type = mime_type
|
asset.mime_type = mime_type
|
||||||
changed = True
|
changed = True
|
||||||
if changed:
|
if changed:
|
||||||
@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
|
|||||||
return False
|
return False
|
||||||
if asset_hash is not None:
|
if asset_hash is not None:
|
||||||
asset.hash = asset_hash
|
asset.hash = asset_hash
|
||||||
if mime_type is not None:
|
if mime_type is not None and not asset.mime_type:
|
||||||
asset.mime_type = mime_type
|
asset.mime_type = mime_type
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|||||||
@ -10,7 +10,7 @@ from decimal import Decimal
|
|||||||
from typing import NamedTuple, Sequence
|
from typing import NamedTuple, Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
from sqlalchemy import delete, exists, select
|
from sqlalchemy import delete, select
|
||||||
from sqlalchemy.dialects import sqlite
|
from sqlalchemy.dialects import sqlite
|
||||||
from sqlalchemy.exc import IntegrityError
|
from sqlalchemy.exc import IntegrityError
|
||||||
from sqlalchemy.orm import Session, noload
|
from sqlalchemy.orm import Session, noload
|
||||||
@ -24,12 +24,14 @@ from app.assets.database.models import (
|
|||||||
)
|
)
|
||||||
from app.assets.database.queries.common import (
|
from app.assets.database.queries.common import (
|
||||||
MAX_BIND_PARAMS,
|
MAX_BIND_PARAMS,
|
||||||
|
apply_metadata_filter,
|
||||||
|
apply_tag_filters,
|
||||||
build_prefix_like_conditions,
|
build_prefix_like_conditions,
|
||||||
build_visible_owner_clause,
|
build_visible_owner_clause,
|
||||||
calculate_rows_per_statement,
|
calculate_rows_per_statement,
|
||||||
iter_chunks,
|
iter_chunks,
|
||||||
)
|
)
|
||||||
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags
|
from app.assets.helpers import escape_sql_like_string, get_utc_now
|
||||||
|
|
||||||
|
|
||||||
def _check_is_scalar(v):
|
def _check_is_scalar(v):
|
||||||
@ -44,15 +46,6 @@ def _check_is_scalar(v):
|
|||||||
|
|
||||||
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
||||||
"""Convert a scalar value to a typed projection row."""
|
"""Convert a scalar value to a typed projection row."""
|
||||||
if value is None:
|
|
||||||
return {
|
|
||||||
"key": key,
|
|
||||||
"ordinal": ordinal,
|
|
||||||
"val_str": None,
|
|
||||||
"val_num": None,
|
|
||||||
"val_bool": None,
|
|
||||||
"val_json": None,
|
|
||||||
}
|
|
||||||
if isinstance(value, bool):
|
if isinstance(value, bool):
|
||||||
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
|
||||||
if isinstance(value, (int, float, Decimal)):
|
if isinstance(value, (int, float, Decimal)):
|
||||||
@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
|
|||||||
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
def convert_metadata_to_rows(key: str, value) -> list[dict]:
|
||||||
"""Turn a metadata key/value into typed projection rows."""
|
"""Turn a metadata key/value into typed projection rows."""
|
||||||
if value is None:
|
if value is None:
|
||||||
return [_scalar_to_row(key, 0, None)]
|
return []
|
||||||
|
|
||||||
if _check_is_scalar(value):
|
if _check_is_scalar(value):
|
||||||
return [_scalar_to_row(key, 0, value)]
|
return [_scalar_to_row(key, 0, value)]
|
||||||
|
|
||||||
if isinstance(value, list):
|
if isinstance(value, list):
|
||||||
if all(_check_is_scalar(x) for x in value):
|
if all(_check_is_scalar(x) for x in value):
|
||||||
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)]
|
return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
|
||||||
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)]
|
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
|
||||||
|
|
||||||
return [{"key": key, "ordinal": 0, "val_json": value}]
|
return [{"key": key, "ordinal": 0, "val_json": value}]
|
||||||
|
|
||||||
|
|
||||||
def _apply_tag_filters(
|
|
||||||
stmt: sa.sql.Select,
|
|
||||||
include_tags: Sequence[str] | None = None,
|
|
||||||
exclude_tags: Sequence[str] | None = None,
|
|
||||||
) -> sa.sql.Select:
|
|
||||||
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
|
||||||
include_tags = normalize_tags(include_tags)
|
|
||||||
exclude_tags = normalize_tags(exclude_tags)
|
|
||||||
|
|
||||||
if include_tags:
|
|
||||||
for tag_name in include_tags:
|
|
||||||
stmt = stmt.where(
|
|
||||||
exists().where(
|
|
||||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
|
||||||
& (AssetReferenceTag.tag_name == tag_name)
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if exclude_tags:
|
|
||||||
stmt = stmt.where(
|
|
||||||
~exists().where(
|
|
||||||
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
|
||||||
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
|
||||||
)
|
|
||||||
)
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def _apply_metadata_filter(
|
|
||||||
stmt: sa.sql.Select,
|
|
||||||
metadata_filter: dict | None = None,
|
|
||||||
) -> sa.sql.Select:
|
|
||||||
"""Apply filters using asset_reference_meta projection table."""
|
|
||||||
if not metadata_filter:
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
|
||||||
return sa.exists().where(
|
|
||||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
|
||||||
AssetReferenceMeta.key == key,
|
|
||||||
*preds,
|
|
||||||
)
|
|
||||||
|
|
||||||
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
|
||||||
if value is None:
|
|
||||||
no_row_for_key = sa.not_(
|
|
||||||
sa.exists().where(
|
|
||||||
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
|
||||||
AssetReferenceMeta.key == key,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
null_row = _exists_for_pred(
|
|
||||||
key,
|
|
||||||
AssetReferenceMeta.val_json.is_(None),
|
|
||||||
AssetReferenceMeta.val_str.is_(None),
|
|
||||||
AssetReferenceMeta.val_num.is_(None),
|
|
||||||
AssetReferenceMeta.val_bool.is_(None),
|
|
||||||
)
|
|
||||||
return sa.or_(no_row_for_key, null_row)
|
|
||||||
|
|
||||||
if isinstance(value, bool):
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
|
||||||
if isinstance(value, (int, float, Decimal)):
|
|
||||||
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
|
||||||
if isinstance(value, str):
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
|
||||||
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
|
||||||
|
|
||||||
for k, v in metadata_filter.items():
|
|
||||||
if isinstance(v, list):
|
|
||||||
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
|
||||||
if ors:
|
|
||||||
stmt = stmt.where(sa.or_(*ors))
|
|
||||||
else:
|
|
||||||
stmt = stmt.where(_exists_clause_for_value(k, v))
|
|
||||||
return stmt
|
|
||||||
|
|
||||||
|
|
||||||
def get_reference_by_id(
|
def get_reference_by_id(
|
||||||
@ -212,6 +128,21 @@ def reference_exists_for_asset_id(
|
|||||||
return session.execute(q).first() is not None
|
return session.execute(q).first() is not None
|
||||||
|
|
||||||
|
|
||||||
|
def reference_exists(
|
||||||
|
session: Session,
|
||||||
|
reference_id: str,
|
||||||
|
) -> bool:
|
||||||
|
"""Return True if a reference with the given ID exists (not soft-deleted)."""
|
||||||
|
q = (
|
||||||
|
select(sa.literal(True))
|
||||||
|
.select_from(AssetReference)
|
||||||
|
.where(AssetReference.id == reference_id)
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
|
.limit(1)
|
||||||
|
)
|
||||||
|
return session.execute(q).first() is not None
|
||||||
|
|
||||||
|
|
||||||
def insert_reference(
|
def insert_reference(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
@ -336,8 +267,8 @@ def list_references_page(
|
|||||||
escaped, esc = escape_sql_like_string(name_contains)
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
base = _apply_tag_filters(base, include_tags, exclude_tags)
|
base = apply_tag_filters(base, include_tags, exclude_tags)
|
||||||
base = _apply_metadata_filter(base, metadata_filter)
|
base = apply_metadata_filter(base, metadata_filter)
|
||||||
|
|
||||||
sort = (sort or "created_at").lower()
|
sort = (sort or "created_at").lower()
|
||||||
order = (order or "desc").lower()
|
order = (order or "desc").lower()
|
||||||
@ -366,8 +297,8 @@ def list_references_page(
|
|||||||
count_stmt = count_stmt.where(
|
count_stmt = count_stmt.where(
|
||||||
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
AssetReference.name.ilike(f"%{escaped}%", escape=esc)
|
||||||
)
|
)
|
||||||
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
|
||||||
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter)
|
count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
|
||||||
|
|
||||||
total = int(session.execute(count_stmt).scalar_one() or 0)
|
total = int(session.execute(count_stmt).scalar_one() or 0)
|
||||||
refs = session.execute(base).unique().scalars().all()
|
refs = session.execute(base).unique().scalars().all()
|
||||||
@ -379,7 +310,7 @@ def list_references_page(
|
|||||||
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
select(AssetReferenceTag.asset_reference_id, Tag.name)
|
||||||
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
.join(Tag, Tag.name == AssetReferenceTag.tag_name)
|
||||||
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
.where(AssetReferenceTag.asset_reference_id.in_(id_list))
|
||||||
.order_by(AssetReferenceTag.added_at)
|
.order_by(AssetReferenceTag.tag_name.asc())
|
||||||
)
|
)
|
||||||
for ref_id, tag_name in rows.all():
|
for ref_id, tag_name in rows.all():
|
||||||
tag_map[ref_id].append(tag_name)
|
tag_map[ref_id].append(tag_name)
|
||||||
@ -492,6 +423,42 @@ def update_reference_updated_at(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
|
||||||
|
"""Delete and rebuild AssetReferenceMeta rows from merged system+user metadata.
|
||||||
|
|
||||||
|
The merged dict is ``{**system_metadata, **user_metadata}`` so user keys
|
||||||
|
override system keys of the same name.
|
||||||
|
"""
|
||||||
|
session.execute(
|
||||||
|
delete(AssetReferenceMeta).where(
|
||||||
|
AssetReferenceMeta.asset_reference_id == ref.id
|
||||||
|
)
|
||||||
|
)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
merged = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
|
||||||
|
if not merged:
|
||||||
|
return
|
||||||
|
|
||||||
|
rows: list[AssetReferenceMeta] = []
|
||||||
|
for k, v in merged.items():
|
||||||
|
for r in convert_metadata_to_rows(k, v):
|
||||||
|
rows.append(
|
||||||
|
AssetReferenceMeta(
|
||||||
|
asset_reference_id=ref.id,
|
||||||
|
key=r["key"],
|
||||||
|
ordinal=int(r["ordinal"]),
|
||||||
|
val_str=r.get("val_str"),
|
||||||
|
val_num=r.get("val_num"),
|
||||||
|
val_bool=r.get("val_bool"),
|
||||||
|
val_json=r.get("val_json"),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
if rows:
|
||||||
|
session.add_all(rows)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
|
||||||
def set_reference_metadata(
|
def set_reference_metadata(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
@ -505,33 +472,24 @@ def set_reference_metadata(
|
|||||||
ref.updated_at = get_utc_now()
|
ref.updated_at = get_utc_now()
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
session.execute(
|
rebuild_metadata_projection(session, ref)
|
||||||
delete(AssetReferenceMeta).where(
|
|
||||||
AssetReferenceMeta.asset_reference_id == reference_id
|
|
||||||
)
|
def set_reference_system_metadata(
|
||||||
)
|
session: Session,
|
||||||
|
reference_id: str,
|
||||||
|
system_metadata: dict | None = None,
|
||||||
|
) -> None:
|
||||||
|
"""Set system_metadata on a reference and rebuild the merged projection."""
|
||||||
|
ref = session.get(AssetReference, reference_id)
|
||||||
|
if not ref:
|
||||||
|
raise ValueError(f"AssetReference {reference_id} not found")
|
||||||
|
|
||||||
|
ref.system_metadata = system_metadata or {}
|
||||||
|
ref.updated_at = get_utc_now()
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
if not user_metadata:
|
rebuild_metadata_projection(session, ref)
|
||||||
return
|
|
||||||
|
|
||||||
rows: list[AssetReferenceMeta] = []
|
|
||||||
for k, v in user_metadata.items():
|
|
||||||
for r in convert_metadata_to_rows(k, v):
|
|
||||||
rows.append(
|
|
||||||
AssetReferenceMeta(
|
|
||||||
asset_reference_id=reference_id,
|
|
||||||
key=r["key"],
|
|
||||||
ordinal=int(r["ordinal"]),
|
|
||||||
val_str=r.get("val_str"),
|
|
||||||
val_num=r.get("val_num"),
|
|
||||||
val_bool=r.get("val_bool"),
|
|
||||||
val_json=r.get("val_json"),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
if rows:
|
|
||||||
session.add_all(rows)
|
|
||||||
session.flush()
|
|
||||||
|
|
||||||
|
|
||||||
def delete_reference_by_id(
|
def delete_reference_by_id(
|
||||||
@ -571,19 +529,19 @@ def soft_delete_reference_by_id(
|
|||||||
def set_reference_preview(
|
def set_reference_preview(
|
||||||
session: Session,
|
session: Session,
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
preview_asset_id: str | None = None,
|
preview_reference_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
|
||||||
ref = session.get(AssetReference, reference_id)
|
ref = session.get(AssetReference, reference_id)
|
||||||
if not ref:
|
if not ref:
|
||||||
raise ValueError(f"AssetReference {reference_id} not found")
|
raise ValueError(f"AssetReference {reference_id} not found")
|
||||||
|
|
||||||
if preview_asset_id is None:
|
if preview_reference_id is None:
|
||||||
ref.preview_id = None
|
ref.preview_id = None
|
||||||
else:
|
else:
|
||||||
if not session.get(Asset, preview_asset_id):
|
if not session.get(AssetReference, preview_reference_id):
|
||||||
raise ValueError(f"Preview Asset {preview_asset_id} not found")
|
raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
|
||||||
ref.preview_id = preview_asset_id
|
ref.preview_id = preview_reference_id
|
||||||
|
|
||||||
ref.updated_at = get_utc_now()
|
ref.updated_at = get_utc_now()
|
||||||
session.flush()
|
session.flush()
|
||||||
@ -609,6 +567,8 @@ def list_references_by_asset_id(
|
|||||||
session.execute(
|
session.execute(
|
||||||
select(AssetReference)
|
select(AssetReference)
|
||||||
.where(AssetReference.asset_id == asset_id)
|
.where(AssetReference.asset_id == asset_id)
|
||||||
|
.where(AssetReference.is_missing == False) # noqa: E712
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
.order_by(AssetReference.id.asc())
|
.order_by(AssetReference.id.asc())
|
||||||
)
|
)
|
||||||
.scalars()
|
.scalars()
|
||||||
@ -616,6 +576,25 @@ def list_references_by_asset_id(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def list_all_file_paths_by_asset_id(
|
||||||
|
session: Session,
|
||||||
|
asset_id: str,
|
||||||
|
) -> list[str]:
|
||||||
|
"""Return every file_path for an asset, including soft-deleted/missing refs.
|
||||||
|
|
||||||
|
Used for orphan cleanup where all on-disk files must be removed.
|
||||||
|
"""
|
||||||
|
return list(
|
||||||
|
session.execute(
|
||||||
|
select(AssetReference.file_path)
|
||||||
|
.where(AssetReference.asset_id == asset_id)
|
||||||
|
.where(AssetReference.file_path.isnot(None))
|
||||||
|
)
|
||||||
|
.scalars()
|
||||||
|
.all()
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def upsert_reference(
|
def upsert_reference(
|
||||||
session: Session,
|
session: Session,
|
||||||
asset_id: str,
|
asset_id: str,
|
||||||
@ -855,6 +834,22 @@ def bulk_update_is_missing(
|
|||||||
return total
|
return total
|
||||||
|
|
||||||
|
|
||||||
|
def update_is_missing_by_asset_id(
|
||||||
|
session: Session, asset_id: str, value: bool
|
||||||
|
) -> int:
|
||||||
|
"""Set is_missing flag for ALL references belonging to an asset.
|
||||||
|
|
||||||
|
Returns: Number of rows updated
|
||||||
|
"""
|
||||||
|
result = session.execute(
|
||||||
|
sa.update(AssetReference)
|
||||||
|
.where(AssetReference.asset_id == asset_id)
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
|
.values(is_missing=value)
|
||||||
|
)
|
||||||
|
return result.rowcount
|
||||||
|
|
||||||
|
|
||||||
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
|
||||||
"""Delete references by their IDs.
|
"""Delete references by their IDs.
|
||||||
|
|
||||||
|
|||||||
@ -1,12 +1,14 @@
|
|||||||
"""Shared utilities for database query modules."""
|
"""Shared utilities for database query modules."""
|
||||||
|
|
||||||
import os
|
import os
|
||||||
from typing import Iterable
|
from decimal import Decimal
|
||||||
|
from typing import Iterable, Sequence
|
||||||
|
|
||||||
import sqlalchemy as sa
|
import sqlalchemy as sa
|
||||||
|
from sqlalchemy import exists
|
||||||
|
|
||||||
from app.assets.database.models import AssetReference
|
from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
|
||||||
from app.assets.helpers import escape_sql_like_string
|
from app.assets.helpers import escape_sql_like_string, normalize_tags
|
||||||
|
|
||||||
MAX_BIND_PARAMS = 800
|
MAX_BIND_PARAMS = 800
|
||||||
|
|
||||||
@ -52,3 +54,74 @@ def build_prefix_like_conditions(
|
|||||||
escaped, esc = escape_sql_like_string(base)
|
escaped, esc = escape_sql_like_string(base)
|
||||||
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
|
||||||
return conds
|
return conds
|
||||||
|
|
||||||
|
|
||||||
|
def apply_tag_filters(
|
||||||
|
stmt: sa.sql.Select,
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
) -> sa.sql.Select:
|
||||||
|
"""include_tags: every tag must be present; exclude_tags: none may be present."""
|
||||||
|
include_tags = normalize_tags(include_tags)
|
||||||
|
exclude_tags = normalize_tags(exclude_tags)
|
||||||
|
|
||||||
|
if include_tags:
|
||||||
|
for tag_name in include_tags:
|
||||||
|
stmt = stmt.where(
|
||||||
|
exists().where(
|
||||||
|
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||||
|
& (AssetReferenceTag.tag_name == tag_name)
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if exclude_tags:
|
||||||
|
stmt = stmt.where(
|
||||||
|
~exists().where(
|
||||||
|
(AssetReferenceTag.asset_reference_id == AssetReference.id)
|
||||||
|
& (AssetReferenceTag.tag_name.in_(exclude_tags))
|
||||||
|
)
|
||||||
|
)
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
|
||||||
|
def apply_metadata_filter(
|
||||||
|
stmt: sa.sql.Select,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
) -> sa.sql.Select:
|
||||||
|
"""Apply filters using asset_reference_meta projection table."""
|
||||||
|
if not metadata_filter:
|
||||||
|
return stmt
|
||||||
|
|
||||||
|
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
|
||||||
|
return sa.exists().where(
|
||||||
|
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||||
|
AssetReferenceMeta.key == key,
|
||||||
|
*preds,
|
||||||
|
)
|
||||||
|
|
||||||
|
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
|
||||||
|
if value is None:
|
||||||
|
return sa.not_(
|
||||||
|
sa.exists().where(
|
||||||
|
AssetReferenceMeta.asset_reference_id == AssetReference.id,
|
||||||
|
AssetReferenceMeta.key == key,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(value, bool):
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
|
||||||
|
if isinstance(value, (int, float, Decimal)):
|
||||||
|
num = value if isinstance(value, Decimal) else Decimal(str(value))
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
|
||||||
|
if isinstance(value, str):
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
|
||||||
|
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
|
||||||
|
|
||||||
|
for k, v in metadata_filter.items():
|
||||||
|
if isinstance(v, list):
|
||||||
|
ors = [_exists_clause_for_value(k, elem) for elem in v]
|
||||||
|
if ors:
|
||||||
|
stmt = stmt.where(sa.or_(*ors))
|
||||||
|
else:
|
||||||
|
stmt = stmt.where(_exists_clause_for_value(k, v))
|
||||||
|
return stmt
|
||||||
|
|||||||
@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
|
|||||||
from sqlalchemy.orm import Session
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
from app.assets.database.models import (
|
from app.assets.database.models import (
|
||||||
|
Asset,
|
||||||
AssetReference,
|
AssetReference,
|
||||||
AssetReferenceMeta,
|
AssetReferenceMeta,
|
||||||
AssetReferenceTag,
|
AssetReferenceTag,
|
||||||
Tag,
|
Tag,
|
||||||
)
|
)
|
||||||
from app.assets.database.queries.common import (
|
from app.assets.database.queries.common import (
|
||||||
|
apply_metadata_filter,
|
||||||
|
apply_tag_filters,
|
||||||
build_visible_owner_clause,
|
build_visible_owner_clause,
|
||||||
iter_row_chunks,
|
iter_row_chunks,
|
||||||
)
|
)
|
||||||
@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
|
|||||||
tag_name
|
tag_name
|
||||||
for (tag_name,) in (
|
for (tag_name,) in (
|
||||||
session.execute(
|
session.execute(
|
||||||
select(AssetReferenceTag.tag_name).where(
|
select(AssetReferenceTag.tag_name)
|
||||||
AssetReferenceTag.asset_reference_id == reference_id
|
.where(AssetReferenceTag.asset_reference_id == reference_id)
|
||||||
)
|
.order_by(AssetReferenceTag.tag_name.asc())
|
||||||
)
|
)
|
||||||
).all()
|
).all()
|
||||||
]
|
]
|
||||||
@ -117,7 +120,7 @@ def set_reference_tags(
|
|||||||
)
|
)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
return SetTagsResult(added=to_add, removed=to_remove, total=desired)
|
return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
|
||||||
|
|
||||||
|
|
||||||
def add_tags_to_reference(
|
def add_tags_to_reference(
|
||||||
@ -272,6 +275,12 @@ def list_tags_with_usage(
|
|||||||
.select_from(AssetReferenceTag)
|
.select_from(AssetReferenceTag)
|
||||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||||
.where(build_visible_owner_clause(owner_id))
|
.where(build_visible_owner_clause(owner_id))
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetReference.is_missing == False, # noqa: E712
|
||||||
|
AssetReferenceTag.tag_name == "missing",
|
||||||
|
)
|
||||||
|
)
|
||||||
.where(AssetReference.deleted_at.is_(None))
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
.group_by(AssetReferenceTag.tag_name)
|
.group_by(AssetReferenceTag.tag_name)
|
||||||
.subquery()
|
.subquery()
|
||||||
@ -308,6 +317,12 @@ def list_tags_with_usage(
|
|||||||
select(AssetReferenceTag.tag_name)
|
select(AssetReferenceTag.tag_name)
|
||||||
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
|
||||||
.where(build_visible_owner_clause(owner_id))
|
.where(build_visible_owner_clause(owner_id))
|
||||||
|
.where(
|
||||||
|
sa.or_(
|
||||||
|
AssetReference.is_missing == False, # noqa: E712
|
||||||
|
AssetReferenceTag.tag_name == "missing",
|
||||||
|
)
|
||||||
|
)
|
||||||
.where(AssetReference.deleted_at.is_(None))
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
.group_by(AssetReferenceTag.tag_name)
|
.group_by(AssetReferenceTag.tag_name)
|
||||||
)
|
)
|
||||||
@ -320,6 +335,53 @@ def list_tags_with_usage(
|
|||||||
return rows_norm, int(total or 0)
|
return rows_norm, int(total or 0)
|
||||||
|
|
||||||
|
|
||||||
|
def list_tag_counts_for_filtered_assets(
|
||||||
|
session: Session,
|
||||||
|
owner_id: str = "",
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
name_contains: str | None = None,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
"""Return tag counts for assets matching the given filters.
|
||||||
|
|
||||||
|
Uses the same filtering logic as list_references_page but returns
|
||||||
|
{tag_name: count} instead of paginated references.
|
||||||
|
"""
|
||||||
|
# Build a subquery of matching reference IDs
|
||||||
|
ref_sq = (
|
||||||
|
select(AssetReference.id)
|
||||||
|
.join(Asset, Asset.id == AssetReference.asset_id)
|
||||||
|
.where(build_visible_owner_clause(owner_id))
|
||||||
|
.where(AssetReference.is_missing == False) # noqa: E712
|
||||||
|
.where(AssetReference.deleted_at.is_(None))
|
||||||
|
)
|
||||||
|
|
||||||
|
if name_contains:
|
||||||
|
escaped, esc = escape_sql_like_string(name_contains)
|
||||||
|
ref_sq = ref_sq.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
|
||||||
|
|
||||||
|
ref_sq = apply_tag_filters(ref_sq, include_tags, exclude_tags)
|
||||||
|
ref_sq = apply_metadata_filter(ref_sq, metadata_filter)
|
||||||
|
ref_sq = ref_sq.subquery()
|
||||||
|
|
||||||
|
# Count tags across those references
|
||||||
|
q = (
|
||||||
|
select(
|
||||||
|
AssetReferenceTag.tag_name,
|
||||||
|
func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
|
||||||
|
)
|
||||||
|
.where(AssetReferenceTag.asset_reference_id.in_(select(ref_sq.c.id)))
|
||||||
|
.group_by(AssetReferenceTag.tag_name)
|
||||||
|
.order_by(func.count(AssetReferenceTag.asset_reference_id).desc(), AssetReferenceTag.tag_name.asc())
|
||||||
|
.limit(limit)
|
||||||
|
)
|
||||||
|
|
||||||
|
rows = session.execute(q).all()
|
||||||
|
return {tag_name: int(cnt) for tag_name, cnt in rows}
|
||||||
|
|
||||||
|
|
||||||
def bulk_insert_tags_and_meta(
|
def bulk_insert_tags_and_meta(
|
||||||
session: Session,
|
session: Session,
|
||||||
tag_rows: list[dict],
|
tag_rows: list[dict],
|
||||||
|
|||||||
@ -18,7 +18,7 @@ from app.assets.database.queries import (
|
|||||||
mark_references_missing_outside_prefixes,
|
mark_references_missing_outside_prefixes,
|
||||||
reassign_asset_references,
|
reassign_asset_references,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_system_metadata,
|
||||||
update_asset_hash_and_mime,
|
update_asset_hash_and_mime,
|
||||||
)
|
)
|
||||||
from app.assets.services.bulk_ingest import (
|
from app.assets.services.bulk_ingest import (
|
||||||
@ -490,8 +490,8 @@ def enrich_asset(
|
|||||||
logging.warning("Failed to hash %s: %s", file_path, e)
|
logging.warning("Failed to hash %s: %s", file_path, e)
|
||||||
|
|
||||||
if extract_metadata and metadata:
|
if extract_metadata and metadata:
|
||||||
user_metadata = metadata.to_user_metadata()
|
system_metadata = metadata.to_user_metadata()
|
||||||
set_reference_metadata(session, reference_id, user_metadata)
|
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||||
|
|
||||||
if full_hash:
|
if full_hash:
|
||||||
existing = get_asset_by_hash(session, full_hash)
|
existing = get_asset_by_hash(session, full_hash)
|
||||||
|
|||||||
@ -16,10 +16,12 @@ from app.assets.database.queries import (
|
|||||||
get_reference_by_id,
|
get_reference_by_id,
|
||||||
get_reference_with_owner_check,
|
get_reference_with_owner_check,
|
||||||
list_references_page,
|
list_references_page,
|
||||||
|
list_all_file_paths_by_asset_id,
|
||||||
list_references_by_asset_id,
|
list_references_by_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
set_reference_preview,
|
set_reference_preview,
|
||||||
set_reference_tags,
|
set_reference_tags,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
update_reference_access_time,
|
update_reference_access_time,
|
||||||
update_reference_name,
|
update_reference_name,
|
||||||
update_reference_updated_at,
|
update_reference_updated_at,
|
||||||
@ -67,6 +69,8 @@ def update_asset_metadata(
|
|||||||
user_metadata: UserMetadata = None,
|
user_metadata: UserMetadata = None,
|
||||||
tag_origin: str = "manual",
|
tag_origin: str = "manual",
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> AssetDetailResult:
|
) -> AssetDetailResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
ref = get_reference_with_owner_check(session, reference_id, owner_id)
|
||||||
@ -103,6 +107,21 @@ def update_asset_metadata(
|
|||||||
)
|
)
|
||||||
touched = True
|
touched = True
|
||||||
|
|
||||||
|
if mime_type is not None:
|
||||||
|
updated = update_asset_hash_and_mime(
|
||||||
|
session, asset_id=ref.asset_id, mime_type=mime_type
|
||||||
|
)
|
||||||
|
if updated:
|
||||||
|
touched = True
|
||||||
|
|
||||||
|
if preview_id is not None:
|
||||||
|
set_reference_preview(
|
||||||
|
session,
|
||||||
|
reference_id=reference_id,
|
||||||
|
preview_reference_id=preview_id,
|
||||||
|
)
|
||||||
|
touched = True
|
||||||
|
|
||||||
if touched and user_metadata is None:
|
if touched and user_metadata is None:
|
||||||
update_reference_updated_at(session, reference_id=reference_id)
|
update_reference_updated_at(session, reference_id=reference_id)
|
||||||
|
|
||||||
@ -159,11 +178,9 @@ def delete_asset_reference(
|
|||||||
session.commit()
|
session.commit()
|
||||||
return True
|
return True
|
||||||
|
|
||||||
# Orphaned asset - delete it and its files
|
# Orphaned asset - gather ALL file paths (including
|
||||||
refs = list_references_by_asset_id(session, asset_id=asset_id)
|
# soft-deleted / missing refs) so their on-disk files get cleaned up.
|
||||||
file_paths = [
|
file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
|
||||||
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
|
|
||||||
]
|
|
||||||
# Also include the just-deleted file path
|
# Also include the just-deleted file path
|
||||||
if file_path:
|
if file_path:
|
||||||
file_paths.append(file_path)
|
file_paths.append(file_path)
|
||||||
@ -185,7 +202,7 @@ def delete_asset_reference(
|
|||||||
|
|
||||||
def set_asset_preview(
|
def set_asset_preview(
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
preview_asset_id: str | None = None,
|
preview_reference_id: str | None = None,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
) -> AssetDetailResult:
|
) -> AssetDetailResult:
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
@ -194,7 +211,7 @@ def set_asset_preview(
|
|||||||
set_reference_preview(
|
set_reference_preview(
|
||||||
session,
|
session,
|
||||||
reference_id=reference_id,
|
reference_id=reference_id,
|
||||||
preview_asset_id=preview_asset_id,
|
preview_reference_id=preview_reference_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
result = fetch_reference_asset_and_tags(
|
result = fetch_reference_asset_and_tags(
|
||||||
@ -263,6 +280,47 @@ def list_assets_page(
|
|||||||
return ListAssetsResult(items=items, total=total)
|
return ListAssetsResult(items=items, total=total)
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_hash_to_path(
|
||||||
|
asset_hash: str,
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> DownloadResolutionResult | None:
|
||||||
|
"""Resolve a blake3 hash to an on-disk file path.
|
||||||
|
|
||||||
|
Only references visible to *owner_id* are considered (owner-less
|
||||||
|
references are always visible).
|
||||||
|
|
||||||
|
Returns a DownloadResolutionResult with abs_path, content_type, and
|
||||||
|
download_name, or None if no asset or live path is found.
|
||||||
|
"""
|
||||||
|
with create_session() as session:
|
||||||
|
asset = queries_get_asset_by_hash(session, asset_hash)
|
||||||
|
if not asset:
|
||||||
|
return None
|
||||||
|
refs = list_references_by_asset_id(session, asset_id=asset.id)
|
||||||
|
visible = [
|
||||||
|
r for r in refs
|
||||||
|
if r.owner_id == "" or r.owner_id == owner_id
|
||||||
|
]
|
||||||
|
abs_path = select_best_live_path(visible)
|
||||||
|
if not abs_path:
|
||||||
|
return None
|
||||||
|
display_name = os.path.basename(abs_path)
|
||||||
|
for ref in visible:
|
||||||
|
if ref.file_path == abs_path and ref.name:
|
||||||
|
display_name = ref.name
|
||||||
|
break
|
||||||
|
ctype = (
|
||||||
|
asset.mime_type
|
||||||
|
or mimetypes.guess_type(display_name)[0]
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
return DownloadResolutionResult(
|
||||||
|
abs_path=abs_path,
|
||||||
|
content_type=ctype,
|
||||||
|
download_name=display_name,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def resolve_asset_for_download(
|
def resolve_asset_for_download(
|
||||||
reference_id: str,
|
reference_id: str,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
|||||||
@ -11,13 +11,14 @@ from app.assets.database.queries import (
|
|||||||
add_tags_to_reference,
|
add_tags_to_reference,
|
||||||
fetch_reference_and_asset,
|
fetch_reference_and_asset,
|
||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
get_existing_asset_ids,
|
|
||||||
get_reference_by_file_path,
|
get_reference_by_file_path,
|
||||||
get_reference_tags,
|
get_reference_tags,
|
||||||
get_or_create_reference,
|
get_or_create_reference,
|
||||||
|
reference_exists,
|
||||||
remove_missing_tag_for_asset_id,
|
remove_missing_tag_for_asset_id,
|
||||||
set_reference_metadata,
|
set_reference_metadata,
|
||||||
set_reference_tags,
|
set_reference_tags,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
upsert_asset,
|
upsert_asset,
|
||||||
upsert_reference,
|
upsert_reference,
|
||||||
validate_tags_exist,
|
validate_tags_exist,
|
||||||
@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags
|
|||||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||||
from app.assets.services.path_utils import (
|
from app.assets.services.path_utils import (
|
||||||
compute_relative_filename,
|
compute_relative_filename,
|
||||||
|
get_name_and_tags_from_asset_path,
|
||||||
resolve_destination_from_tags,
|
resolve_destination_from_tags,
|
||||||
validate_path_within_base,
|
validate_path_within_base,
|
||||||
)
|
)
|
||||||
@ -65,7 +67,7 @@ def _ingest_file_from_path(
|
|||||||
|
|
||||||
with create_session() as session:
|
with create_session() as session:
|
||||||
if preview_id:
|
if preview_id:
|
||||||
if preview_id not in get_existing_asset_ids(session, [preview_id]):
|
if not reference_exists(session, preview_id):
|
||||||
preview_id = None
|
preview_id = None
|
||||||
|
|
||||||
asset, asset_created, asset_updated = upsert_asset(
|
asset, asset_created, asset_updated = upsert_asset(
|
||||||
@ -135,6 +137,8 @@ def _register_existing_asset(
|
|||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
tag_origin: str = "manual",
|
tag_origin: str = "manual",
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> RegisterAssetResult:
|
) -> RegisterAssetResult:
|
||||||
user_metadata = user_metadata or {}
|
user_metadata = user_metadata or {}
|
||||||
|
|
||||||
@ -143,14 +147,25 @@ def _register_existing_asset(
|
|||||||
if not asset:
|
if not asset:
|
||||||
raise ValueError(f"No asset with hash {asset_hash}")
|
raise ValueError(f"No asset with hash {asset_hash}")
|
||||||
|
|
||||||
|
if mime_type and not asset.mime_type:
|
||||||
|
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
|
||||||
|
|
||||||
|
if preview_id:
|
||||||
|
if not reference_exists(session, preview_id):
|
||||||
|
preview_id = None
|
||||||
|
|
||||||
ref, ref_created = get_or_create_reference(
|
ref, ref_created = get_or_create_reference(
|
||||||
session,
|
session,
|
||||||
asset_id=asset.id,
|
asset_id=asset.id,
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
name=name,
|
name=name,
|
||||||
|
preview_id=preview_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if not ref_created:
|
if not ref_created:
|
||||||
|
if preview_id and ref.preview_id != preview_id:
|
||||||
|
ref.preview_id = preview_id
|
||||||
|
|
||||||
tag_names = get_reference_tags(session, reference_id=ref.id)
|
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||||
result = RegisterAssetResult(
|
result = RegisterAssetResult(
|
||||||
ref=extract_reference_data(ref),
|
ref=extract_reference_data(ref),
|
||||||
@ -242,6 +257,8 @@ def upload_from_temp_path(
|
|||||||
client_filename: str | None = None,
|
client_filename: str | None = None,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
expected_hash: str | None = None,
|
expected_hash: str | None = None,
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> UploadResult:
|
) -> UploadResult:
|
||||||
try:
|
try:
|
||||||
digest, _ = hashing.compute_blake3_hash(temp_path)
|
digest, _ = hashing.compute_blake3_hash(temp_path)
|
||||||
@ -270,6 +287,8 @@ def upload_from_temp_path(
|
|||||||
tags=tags or [],
|
tags=tags or [],
|
||||||
tag_origin="manual",
|
tag_origin="manual",
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
|
mime_type=mime_type,
|
||||||
|
preview_id=preview_id,
|
||||||
)
|
)
|
||||||
return UploadResult(
|
return UploadResult(
|
||||||
ref=result.ref,
|
ref=result.ref,
|
||||||
@ -291,7 +310,7 @@ def upload_from_temp_path(
|
|||||||
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
|
||||||
validate_path_within_base(dest_abs, base_dir)
|
validate_path_within_base(dest_abs, base_dir)
|
||||||
|
|
||||||
content_type = (
|
content_type = mime_type or (
|
||||||
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
|
||||||
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
or mimetypes.guess_type(hashed_basename, strict=False)[0]
|
||||||
or "application/octet-stream"
|
or "application/octet-stream"
|
||||||
@ -315,7 +334,7 @@ def upload_from_temp_path(
|
|||||||
mime_type=content_type,
|
mime_type=content_type,
|
||||||
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
info_name=_sanitize_filename(name or client_filename, fallback=digest),
|
||||||
owner_id=owner_id,
|
owner_id=owner_id,
|
||||||
preview_id=None,
|
preview_id=preview_id,
|
||||||
user_metadata=user_metadata or {},
|
user_metadata=user_metadata or {},
|
||||||
tags=tags,
|
tags=tags,
|
||||||
tag_origin="manual",
|
tag_origin="manual",
|
||||||
@ -342,30 +361,99 @@ def upload_from_temp_path(
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def register_file_in_place(
|
||||||
|
abs_path: str,
|
||||||
|
name: str,
|
||||||
|
tags: list[str],
|
||||||
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
) -> UploadResult:
|
||||||
|
"""Register an already-saved file in the asset database without moving it.
|
||||||
|
|
||||||
|
Tags are derived from the filesystem path (root category + subfolder names),
|
||||||
|
merged with any caller-provided tags, matching the behavior of the scanner.
|
||||||
|
If the path is not under a known root, only the caller-provided tags are used.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
_, path_tags = get_name_and_tags_from_asset_path(abs_path)
|
||||||
|
except ValueError:
|
||||||
|
path_tags = []
|
||||||
|
merged_tags = normalize_tags([*path_tags, *tags])
|
||||||
|
|
||||||
|
try:
|
||||||
|
digest, _ = hashing.compute_blake3_hash(abs_path)
|
||||||
|
except ImportError as e:
|
||||||
|
raise DependencyMissingError(str(e))
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"failed to hash file: {e}")
|
||||||
|
asset_hash = "blake3:" + digest
|
||||||
|
|
||||||
|
size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
|
||||||
|
content_type = mime_type or (
|
||||||
|
mimetypes.guess_type(abs_path, strict=False)[0]
|
||||||
|
or "application/octet-stream"
|
||||||
|
)
|
||||||
|
|
||||||
|
ingest_result = _ingest_file_from_path(
|
||||||
|
abs_path=abs_path,
|
||||||
|
asset_hash=asset_hash,
|
||||||
|
size_bytes=size_bytes,
|
||||||
|
mtime_ns=mtime_ns,
|
||||||
|
mime_type=content_type,
|
||||||
|
info_name=_sanitize_filename(name, fallback=digest),
|
||||||
|
owner_id=owner_id,
|
||||||
|
tags=merged_tags,
|
||||||
|
tag_origin="upload",
|
||||||
|
require_existing_tags=False,
|
||||||
|
)
|
||||||
|
reference_id = ingest_result.reference_id
|
||||||
|
if not reference_id:
|
||||||
|
raise RuntimeError("failed to create asset reference")
|
||||||
|
|
||||||
|
with create_session() as session:
|
||||||
|
pair = fetch_reference_and_asset(
|
||||||
|
session, reference_id=reference_id, owner_id=owner_id
|
||||||
|
)
|
||||||
|
if not pair:
|
||||||
|
raise RuntimeError("inconsistent DB state after ingest")
|
||||||
|
ref, asset = pair
|
||||||
|
tag_names = get_reference_tags(session, reference_id=ref.id)
|
||||||
|
|
||||||
|
return UploadResult(
|
||||||
|
ref=extract_reference_data(ref),
|
||||||
|
asset=extract_asset_data(asset),
|
||||||
|
tags=tag_names,
|
||||||
|
created_new=ingest_result.asset_created,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def create_from_hash(
|
def create_from_hash(
|
||||||
hash_str: str,
|
hash_str: str,
|
||||||
name: str,
|
name: str,
|
||||||
tags: list[str] | None = None,
|
tags: list[str] | None = None,
|
||||||
user_metadata: dict | None = None,
|
user_metadata: dict | None = None,
|
||||||
owner_id: str = "",
|
owner_id: str = "",
|
||||||
|
mime_type: str | None = None,
|
||||||
|
preview_id: str | None = None,
|
||||||
) -> UploadResult | None:
|
) -> UploadResult | None:
|
||||||
canonical = hash_str.strip().lower()
|
canonical = hash_str.strip().lower()
|
||||||
|
|
||||||
with create_session() as session:
|
try:
|
||||||
asset = get_asset_by_hash(session, asset_hash=canonical)
|
result = _register_existing_asset(
|
||||||
if not asset:
|
asset_hash=canonical,
|
||||||
return None
|
name=_sanitize_filename(
|
||||||
|
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
||||||
result = _register_existing_asset(
|
),
|
||||||
asset_hash=canonical,
|
user_metadata=user_metadata or {},
|
||||||
name=_sanitize_filename(
|
tags=tags or [],
|
||||||
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
|
tag_origin="manual",
|
||||||
),
|
owner_id=owner_id,
|
||||||
user_metadata=user_metadata or {},
|
mime_type=mime_type,
|
||||||
tags=tags or [],
|
preview_id=preview_id,
|
||||||
tag_origin="manual",
|
)
|
||||||
owner_id=owner_id,
|
except ValueError:
|
||||||
)
|
logging.warning("create_from_hash: no asset found for hash %s", canonical)
|
||||||
|
return None
|
||||||
|
|
||||||
return UploadResult(
|
return UploadResult(
|
||||||
ref=result.ref,
|
ref=result.ref,
|
||||||
|
|||||||
@ -25,7 +25,9 @@ class ReferenceData:
|
|||||||
preview_id: str | None
|
preview_id: str | None
|
||||||
created_at: datetime
|
created_at: datetime
|
||||||
updated_at: datetime
|
updated_at: datetime
|
||||||
last_access_time: datetime | None
|
system_metadata: dict[str, Any] | None = None
|
||||||
|
job_id: str | None = None
|
||||||
|
last_access_time: datetime | None = None
|
||||||
|
|
||||||
|
|
||||||
@dataclass(frozen=True)
|
@dataclass(frozen=True)
|
||||||
@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
|
|||||||
file_path=ref.file_path,
|
file_path=ref.file_path,
|
||||||
user_metadata=ref.user_metadata,
|
user_metadata=ref.user_metadata,
|
||||||
preview_id=ref.preview_id,
|
preview_id=ref.preview_id,
|
||||||
|
system_metadata=ref.system_metadata,
|
||||||
|
job_id=ref.job_id,
|
||||||
created_at=ref.created_at,
|
created_at=ref.created_at,
|
||||||
updated_at=ref.updated_at,
|
updated_at=ref.updated_at,
|
||||||
last_access_time=ref.last_access_time,
|
last_access_time=ref.last_access_time,
|
||||||
|
|||||||
@ -1,3 +1,5 @@
|
|||||||
|
from typing import Sequence
|
||||||
|
|
||||||
from app.assets.database.queries import (
|
from app.assets.database.queries import (
|
||||||
AddTagsResult,
|
AddTagsResult,
|
||||||
RemoveTagsResult,
|
RemoveTagsResult,
|
||||||
@ -6,6 +8,7 @@ from app.assets.database.queries import (
|
|||||||
list_tags_with_usage,
|
list_tags_with_usage,
|
||||||
remove_tags_from_reference,
|
remove_tags_from_reference,
|
||||||
)
|
)
|
||||||
|
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
|
||||||
from app.assets.services.schemas import TagUsage
|
from app.assets.services.schemas import TagUsage
|
||||||
from app.database.db import create_session
|
from app.database.db import create_session
|
||||||
|
|
||||||
@ -73,3 +76,23 @@ def list_tags(
|
|||||||
)
|
)
|
||||||
|
|
||||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||||
|
|
||||||
|
|
||||||
|
def list_tag_histogram(
|
||||||
|
owner_id: str = "",
|
||||||
|
include_tags: Sequence[str] | None = None,
|
||||||
|
exclude_tags: Sequence[str] | None = None,
|
||||||
|
name_contains: str | None = None,
|
||||||
|
metadata_filter: dict | None = None,
|
||||||
|
limit: int = 100,
|
||||||
|
) -> dict[str, int]:
|
||||||
|
with create_session() as session:
|
||||||
|
return list_tag_counts_for_filtered_assets(
|
||||||
|
session,
|
||||||
|
owner_id=owner_id,
|
||||||
|
include_tags=include_tags,
|
||||||
|
exclude_tags=exclude_tags,
|
||||||
|
name_contains=name_contains,
|
||||||
|
metadata_filter=metadata_filter,
|
||||||
|
limit=limit,
|
||||||
|
)
|
||||||
|
|||||||
@ -1,9 +1,18 @@
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
|
from sqlalchemy import MetaData
|
||||||
from sqlalchemy.orm import DeclarativeBase
|
from sqlalchemy.orm import DeclarativeBase
|
||||||
|
|
||||||
|
NAMING_CONVENTION = {
|
||||||
|
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
|
||||||
|
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
|
||||||
|
"ck": "ck_%(table_name)s_%(constraint_name)s",
|
||||||
|
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
|
||||||
|
"pk": "pk_%(table_name)s",
|
||||||
|
}
|
||||||
|
|
||||||
class Base(DeclarativeBase):
|
class Base(DeclarativeBase):
|
||||||
pass
|
metadata = MetaData(naming_convention=NAMING_CONVENTION)
|
||||||
|
|
||||||
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
|
||||||
fields = obj.__table__.columns.keys()
|
fields = obj.__table__.columns.keys()
|
||||||
|
|||||||
@ -6,6 +6,7 @@ import uuid
|
|||||||
import glob
|
import glob
|
||||||
import shutil
|
import shutil
|
||||||
import logging
|
import logging
|
||||||
|
import tempfile
|
||||||
from aiohttp import web
|
from aiohttp import web
|
||||||
from urllib import parse
|
from urllib import parse
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
@ -377,8 +378,15 @@ class UserManager():
|
|||||||
try:
|
try:
|
||||||
body = await request.read()
|
body = await request.read()
|
||||||
|
|
||||||
with open(path, "wb") as f:
|
dir_name = os.path.dirname(path)
|
||||||
f.write(body)
|
fd, tmp_path = tempfile.mkstemp(dir=dir_name)
|
||||||
|
try:
|
||||||
|
with os.fdopen(fd, "wb") as f:
|
||||||
|
f.write(body)
|
||||||
|
os.replace(tmp_path, path)
|
||||||
|
except:
|
||||||
|
os.unlink(tmp_path)
|
||||||
|
raise
|
||||||
except OSError as e:
|
except OSError as e:
|
||||||
logging.warning(f"Error saving file '{path}': {e}")
|
logging.warning(f"Error saving file '{path}': {e}")
|
||||||
return web.Response(
|
return web.Response(
|
||||||
|
|||||||
@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
|
|||||||
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
|
||||||
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
|
||||||
|
|
||||||
|
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
|
||||||
|
|
||||||
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
|
||||||
|
|
||||||
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
|
||||||
@ -147,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
|
|||||||
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
|
||||||
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
|
||||||
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
|
||||||
|
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
|
||||||
|
|
||||||
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
|
||||||
|
|
||||||
@ -260,4 +263,6 @@ else:
|
|||||||
args.fast = set(args.fast)
|
args.fast = set(args.fast)
|
||||||
|
|
||||||
def enables_dynamic_vram():
|
def enables_dynamic_vram():
|
||||||
|
if args.enable_dynamic_vram:
|
||||||
|
return True
|
||||||
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu
|
||||||
|
|||||||
@ -93,6 +93,50 @@ class IndexListCallbacks:
|
|||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
|
||||||
|
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
|
||||||
|
return None
|
||||||
|
cond_tensor = cond_value.cond
|
||||||
|
if temporal_dim >= cond_tensor.ndim:
|
||||||
|
return None
|
||||||
|
|
||||||
|
cond_size = cond_tensor.size(temporal_dim)
|
||||||
|
|
||||||
|
if temporal_scale == 1:
|
||||||
|
expected_size = x_in.size(window.dim) - temporal_offset
|
||||||
|
if cond_size != expected_size:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if temporal_offset == 0 and temporal_scale == 1:
|
||||||
|
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
|
||||||
|
return cond_value._copy_with(sliced)
|
||||||
|
|
||||||
|
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
|
||||||
|
if temporal_offset > 0:
|
||||||
|
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
|
||||||
|
indices = [i for i in indices if 0 <= i]
|
||||||
|
else:
|
||||||
|
indices = list(window.index_list)
|
||||||
|
|
||||||
|
if not indices:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if temporal_scale > 1:
|
||||||
|
scaled = []
|
||||||
|
for i in indices:
|
||||||
|
for k in range(temporal_scale):
|
||||||
|
si = i * temporal_scale + k
|
||||||
|
if si < cond_size:
|
||||||
|
scaled.append(si)
|
||||||
|
indices = scaled
|
||||||
|
if not indices:
|
||||||
|
return None
|
||||||
|
|
||||||
|
idx = tuple([slice(None)] * temporal_dim + [indices])
|
||||||
|
sliced = cond_tensor[idx].to(device)
|
||||||
|
return cond_value._copy_with(sliced)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class ContextSchedule:
|
class ContextSchedule:
|
||||||
name: str
|
name: str
|
||||||
@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
new_cond_item[cond_key] = result
|
new_cond_item[cond_key] = result
|
||||||
handled = True
|
handled = True
|
||||||
break
|
break
|
||||||
|
if not handled and self._model is not None:
|
||||||
|
result = self._model.resize_cond_for_context_window(
|
||||||
|
cond_key, cond_value, window, x_in, device,
|
||||||
|
retain_index_list=self.cond_retain_index_list)
|
||||||
|
if result is not None:
|
||||||
|
new_cond_item[cond_key] = result
|
||||||
|
handled = True
|
||||||
if handled:
|
if handled:
|
||||||
continue
|
continue
|
||||||
if isinstance(cond_value, torch.Tensor):
|
if isinstance(cond_value, torch.Tensor):
|
||||||
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \
|
if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
|
||||||
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
|
||||||
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
new_cond_item[cond_key] = window.get_tensor(cond_value, device)
|
||||||
# Handle audio_embed (temporal dim is 1)
|
# Handle audio_embed (temporal dim is 1)
|
||||||
@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
|
|||||||
return context_windows
|
return context_windows
|
||||||
|
|
||||||
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
|
||||||
|
self._model = model
|
||||||
self.set_step(timestep, model_options)
|
self.set_step(timestep, model_options)
|
||||||
context_windows = self.get_context_windows(model, x_in, model_options)
|
context_windows = self.get_context_windows(model, x_in, model_options)
|
||||||
enumerated_context_windows = list(enumerate(context_windows))
|
enumerated_context_windows = list(enumerate(context_windows))
|
||||||
|
|||||||
@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
|
|||||||
output_block[i:i + slice_size].copy_(block)
|
output_block[i:i + slice_size].copy_(block)
|
||||||
|
|
||||||
return output_fp4, to_blocked(output_block, flatten=False)
|
return output_fp4, to_blocked(output_block, flatten=False)
|
||||||
|
|
||||||
|
|
||||||
|
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
|
||||||
|
def roundup(x_val, multiple):
|
||||||
|
return ((x_val + multiple - 1) // multiple) * multiple
|
||||||
|
|
||||||
|
if pad_32x:
|
||||||
|
rows, cols = x.shape
|
||||||
|
padded_rows = roundup(rows, 32)
|
||||||
|
padded_cols = roundup(cols, 32)
|
||||||
|
if padded_rows != rows or padded_cols != cols:
|
||||||
|
x = torch.nn.functional.pad(x, (0, padded_cols - cols, 0, padded_rows - rows))
|
||||||
|
|
||||||
|
F8_E4M3_MAX = 448.0
|
||||||
|
E8M0_BIAS = 127
|
||||||
|
BLOCK_SIZE = 32
|
||||||
|
|
||||||
|
rows, cols = x.shape
|
||||||
|
x_blocked = x.reshape(rows, -1, BLOCK_SIZE)
|
||||||
|
max_abs = torch.amax(torch.abs(x_blocked), dim=-1)
|
||||||
|
|
||||||
|
# E8M0 block scales (power-of-2 exponents)
|
||||||
|
scale_needed = torch.clamp(max_abs.float() / F8_E4M3_MAX, min=2**(-127))
|
||||||
|
exp_biased = torch.clamp(torch.ceil(torch.log2(scale_needed)).to(torch.int32) + E8M0_BIAS, 0, 254)
|
||||||
|
block_scales_e8m0 = exp_biased.to(torch.uint8)
|
||||||
|
|
||||||
|
zero_mask = (max_abs == 0)
|
||||||
|
block_scales_f32 = (block_scales_e8m0.to(torch.int32) << 23).view(torch.float32)
|
||||||
|
block_scales_f32 = torch.where(zero_mask, torch.ones_like(block_scales_f32), block_scales_f32)
|
||||||
|
|
||||||
|
# Scale per-block then stochastic round
|
||||||
|
data_scaled = (x_blocked.float() / block_scales_f32.unsqueeze(-1)).reshape(rows, cols)
|
||||||
|
output_fp8 = stochastic_rounding(data_scaled, torch.float8_e4m3fn, seed=seed)
|
||||||
|
|
||||||
|
block_scales_e8m0 = torch.where(zero_mask, torch.zeros_like(block_scales_e8m0), block_scales_e8m0)
|
||||||
|
return output_fp8, to_blocked(block_scales_e8m0, flatten=False).view(torch.float8_e8m0fnu)
|
||||||
|
|||||||
@ -136,16 +136,7 @@ class ResBlock(nn.Module):
|
|||||||
ops.Linear(c_hidden, c),
|
ops.Linear(c_hidden, c),
|
||||||
)
|
)
|
||||||
|
|
||||||
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True)
|
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
|
||||||
|
|
||||||
# Init weights
|
|
||||||
def _basic_init(module):
|
|
||||||
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
|
|
||||||
torch.nn.init.xavier_uniform_(module.weight)
|
|
||||||
if module.bias is not None:
|
|
||||||
nn.init.constant_(module.bias, 0)
|
|
||||||
|
|
||||||
self.apply(_basic_init)
|
|
||||||
|
|
||||||
def _norm(self, x, norm):
|
def _norm(self, x, norm):
|
||||||
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
|
||||||
|
|||||||
@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
|
|||||||
k.reshape(b, s2, self.num_heads * self.head_dim),
|
k.reshape(b, s2, self.num_heads * self.head_dim),
|
||||||
v,
|
v,
|
||||||
heads=self.num_heads,
|
heads=self.num_heads,
|
||||||
|
low_precision_attention=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
out = self.out_proj(x)
|
out = self.out_proj(x)
|
||||||
@ -412,6 +413,7 @@ class Attention(nn.Module):
|
|||||||
key.reshape(B, N, self.num_heads * self.head_dim),
|
key.reshape(B, N, self.num_heads * self.head_dim),
|
||||||
value,
|
value,
|
||||||
heads=self.num_heads,
|
heads=self.num_heads,
|
||||||
|
low_precision_attention=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
x = self.out_proj(x)
|
x = self.out_proj(x)
|
||||||
|
|||||||
@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
|
|||||||
self.in_channels = in_channels
|
self.in_channels = in_channels
|
||||||
self.out_channels = out_channels
|
self.out_channels = out_channels
|
||||||
|
|
||||||
|
if isinstance(stride, int):
|
||||||
|
self.time_stride = stride
|
||||||
|
else:
|
||||||
|
self.time_stride = stride[0]
|
||||||
|
|
||||||
kernel_size = (kernel_size, kernel_size, kernel_size)
|
kernel_size = (kernel_size, kernel_size, kernel_size)
|
||||||
self.time_kernel_size = kernel_size[0]
|
self.time_kernel_size = kernel_size[0]
|
||||||
|
|
||||||
@ -58,16 +63,25 @@ class CausalConv3d(nn.Module):
|
|||||||
pieces = [ cached, x ]
|
pieces = [ cached, x ]
|
||||||
if is_end and not causal:
|
if is_end and not causal:
|
||||||
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
|
||||||
|
input_length = sum([piece.shape[2] for piece in pieces])
|
||||||
|
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
|
||||||
|
|
||||||
needs_caching = not is_end
|
needs_caching = not is_end
|
||||||
if needs_caching and x.shape[2] >= self.time_kernel_size - 1:
|
if needs_caching and cache_length == 0:
|
||||||
|
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
|
||||||
needs_caching = False
|
needs_caching = False
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
if needs_caching and x.shape[2] >= cache_length:
|
||||||
|
needs_caching = False
|
||||||
|
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||||
|
|
||||||
x = torch.cat(pieces, dim=2)
|
x = torch.cat(pieces, dim=2)
|
||||||
|
del pieces
|
||||||
|
del cached
|
||||||
|
|
||||||
if needs_caching:
|
if needs_caching:
|
||||||
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False)
|
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
|
||||||
|
elif is_end:
|
||||||
|
self.temporal_cache_state[tid] = (None, True)
|
||||||
|
|
||||||
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]
|
||||||
|
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
|
|||||||
from .pixel_norm import PixelNorm
|
from .pixel_norm import PixelNorm
|
||||||
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
|
||||||
import comfy.ops
|
import comfy.ops
|
||||||
|
import comfy.model_management
|
||||||
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
|
||||||
|
|
||||||
ops = comfy.ops.disable_weight_init
|
ops = comfy.ops.disable_weight_init
|
||||||
@ -232,10 +233,7 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor:
|
def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
|
||||||
r"""The forward method of the `Encoder` class."""
|
|
||||||
|
|
||||||
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
|
||||||
sample = self.conv_in(sample)
|
sample = self.conv_in(sample)
|
||||||
|
|
||||||
checkpoint_fn = (
|
checkpoint_fn = (
|
||||||
@ -246,10 +244,14 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
for down_block in self.down_blocks:
|
for down_block in self.down_blocks:
|
||||||
sample = checkpoint_fn(down_block)(sample)
|
sample = checkpoint_fn(down_block)(sample)
|
||||||
|
if sample is None or sample.shape[2] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
sample = self.conv_norm_out(sample)
|
sample = self.conv_norm_out(sample)
|
||||||
sample = self.conv_act(sample)
|
sample = self.conv_act(sample)
|
||||||
sample = self.conv_out(sample)
|
sample = self.conv_out(sample)
|
||||||
|
if sample is None or sample.shape[2] == 0:
|
||||||
|
return None
|
||||||
|
|
||||||
if self.latent_log_var == "uniform":
|
if self.latent_log_var == "uniform":
|
||||||
last_channel = sample[:, -1:, ...]
|
last_channel = sample[:, -1:, ...]
|
||||||
@ -281,9 +283,35 @@ class Encoder(nn.Module):
|
|||||||
|
|
||||||
return sample
|
return sample
|
||||||
|
|
||||||
|
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
|
||||||
|
r"""The forward method of the `Encoder` class."""
|
||||||
|
|
||||||
|
max_chunk_size = get_max_chunk_size(sample.device if device is None else device) * 2 # encoder is more memory-efficient than decoder
|
||||||
|
frame_size = sample[:, :, :1, :, :].numel() * sample.element_size()
|
||||||
|
frame_size = int(frame_size * (self.conv_in.out_channels / self.conv_in.in_channels))
|
||||||
|
|
||||||
|
outputs = []
|
||||||
|
samples = [sample[:, :, :1, :, :]]
|
||||||
|
if sample.shape[2] > 1:
|
||||||
|
chunk_t = max(2, max_chunk_size // frame_size)
|
||||||
|
if chunk_t < 4:
|
||||||
|
chunk_t = 2
|
||||||
|
elif chunk_t < 8:
|
||||||
|
chunk_t = 4
|
||||||
|
else:
|
||||||
|
chunk_t = (chunk_t // 8) * 8
|
||||||
|
samples += list(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))
|
||||||
|
for chunk_idx, chunk in enumerate(samples):
|
||||||
|
if chunk_idx == len(samples) - 1:
|
||||||
|
mark_conv3d_ended(self)
|
||||||
|
chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
|
||||||
|
output = self._forward_chunk(chunk)
|
||||||
|
if output is not None:
|
||||||
|
outputs.append(output)
|
||||||
|
|
||||||
|
return torch_cat_if_needed(outputs, dim=2)
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
#No encoder support so just flag the end so it doesnt use the cache.
|
|
||||||
mark_conv3d_ended(self)
|
|
||||||
try:
|
try:
|
||||||
return self.forward_orig(*args, **kwargs)
|
return self.forward_orig(*args, **kwargs)
|
||||||
finally:
|
finally:
|
||||||
@ -296,7 +324,23 @@ class Encoder(nn.Module):
|
|||||||
module.temporal_cache_state.pop(tid, None)
|
module.temporal_cache_state.pop(tid, None)
|
||||||
|
|
||||||
|
|
||||||
MAX_CHUNK_SIZE=(128 * 1024 ** 2)
|
MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
|
||||||
|
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
|
||||||
|
MIN_CHUNK_SIZE = 32 * 1024 ** 2
|
||||||
|
MAX_CHUNK_SIZE = 128 * 1024 ** 2
|
||||||
|
|
||||||
|
def get_max_chunk_size(device: torch.device) -> int:
|
||||||
|
total_memory = comfy.model_management.get_total_memory(dev=device)
|
||||||
|
|
||||||
|
if total_memory <= MIN_VRAM_FOR_CHUNK_SCALING:
|
||||||
|
return MIN_CHUNK_SIZE
|
||||||
|
if total_memory >= MAX_VRAM_FOR_CHUNK_SCALING:
|
||||||
|
return MAX_CHUNK_SIZE
|
||||||
|
|
||||||
|
interp = (total_memory - MIN_VRAM_FOR_CHUNK_SCALING) / (
|
||||||
|
MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
|
||||||
|
)
|
||||||
|
return int(MIN_CHUNK_SIZE + interp * (MAX_CHUNK_SIZE - MIN_CHUNK_SIZE))
|
||||||
|
|
||||||
class Decoder(nn.Module):
|
class Decoder(nn.Module):
|
||||||
r"""
|
r"""
|
||||||
@ -456,6 +500,17 @@ class Decoder(nn.Module):
|
|||||||
|
|
||||||
self.gradient_checkpointing = False
|
self.gradient_checkpointing = False
|
||||||
|
|
||||||
|
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
|
||||||
|
ts, hs, ws, to = 1, 1, 1, 0
|
||||||
|
for block in self.up_blocks:
|
||||||
|
if isinstance(block, DepthToSpaceUpsample):
|
||||||
|
ts *= block.stride[0]
|
||||||
|
hs *= block.stride[1]
|
||||||
|
ws *= block.stride[2]
|
||||||
|
if block.stride[0] > 1:
|
||||||
|
to = to * block.stride[0] + 1
|
||||||
|
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
|
||||||
|
|
||||||
self.timestep_conditioning = timestep_conditioning
|
self.timestep_conditioning = timestep_conditioning
|
||||||
|
|
||||||
if timestep_conditioning:
|
if timestep_conditioning:
|
||||||
@ -477,11 +532,62 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor:
|
def decode_output_shape(self, input_shape):
|
||||||
|
c, (ts, hs, ws), to = self._output_scale
|
||||||
|
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
|
||||||
|
|
||||||
|
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
|
||||||
|
sample = sample_ref[0]
|
||||||
|
sample_ref[0] = None
|
||||||
|
if idx >= len(self.up_blocks):
|
||||||
|
sample = self.conv_norm_out(sample)
|
||||||
|
if timestep_shift_scale is not None:
|
||||||
|
shift, scale = timestep_shift_scale
|
||||||
|
sample = sample * (1 + scale) + shift
|
||||||
|
sample = self.conv_act(sample)
|
||||||
|
if ended:
|
||||||
|
mark_conv3d_ended(self.conv_out)
|
||||||
|
sample = self.conv_out(sample, causal=self.causal)
|
||||||
|
if sample is not None and sample.shape[2] > 0:
|
||||||
|
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
||||||
|
t = sample.shape[2]
|
||||||
|
output_buffer[:, :, output_offset[0]:output_offset[0] + t].copy_(sample)
|
||||||
|
output_offset[0] += t
|
||||||
|
return
|
||||||
|
|
||||||
|
up_block = self.up_blocks[idx]
|
||||||
|
if ended:
|
||||||
|
mark_conv3d_ended(up_block)
|
||||||
|
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
||||||
|
sample = checkpoint_fn(up_block)(
|
||||||
|
sample, causal=self.causal, timestep=scaled_timestep
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
||||||
|
|
||||||
|
if sample is None or sample.shape[2] == 0:
|
||||||
|
return
|
||||||
|
|
||||||
|
total_bytes = sample.numel() * sample.element_size()
|
||||||
|
num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size
|
||||||
|
|
||||||
|
if num_chunks == 1:
|
||||||
|
# when we are not chunking, detach our x so the callee can free it as soon as they are done
|
||||||
|
next_sample_ref = [sample]
|
||||||
|
del sample
|
||||||
|
self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
||||||
|
|
||||||
|
for chunk_idx, sample1 in enumerate(samples):
|
||||||
|
self.run_up(idx + 1, [sample1], ended and chunk_idx == len(samples) - 1, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
|
|
||||||
def forward_orig(
|
def forward_orig(
|
||||||
self,
|
self,
|
||||||
sample: torch.FloatTensor,
|
sample: torch.FloatTensor,
|
||||||
timestep: Optional[torch.Tensor] = None,
|
timestep: Optional[torch.Tensor] = None,
|
||||||
|
output_buffer: Optional[torch.Tensor] = None,
|
||||||
) -> torch.FloatTensor:
|
) -> torch.FloatTensor:
|
||||||
r"""The forward method of the `Decoder` class."""
|
r"""The forward method of the `Decoder` class."""
|
||||||
batch_size = sample.shape[0]
|
batch_size = sample.shape[0]
|
||||||
@ -496,6 +602,7 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
|
|
||||||
timestep_shift_scale = None
|
timestep_shift_scale = None
|
||||||
|
scaled_timestep = None
|
||||||
if self.timestep_conditioning:
|
if self.timestep_conditioning:
|
||||||
assert (
|
assert (
|
||||||
timestep is not None
|
timestep is not None
|
||||||
@ -523,48 +630,18 @@ class Decoder(nn.Module):
|
|||||||
)
|
)
|
||||||
timestep_shift_scale = ada_values.unbind(dim=1)
|
timestep_shift_scale = ada_values.unbind(dim=1)
|
||||||
|
|
||||||
output = []
|
if output_buffer is None:
|
||||||
|
output_buffer = torch.empty(
|
||||||
|
self.decode_output_shape(sample.shape),
|
||||||
|
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
|
||||||
|
)
|
||||||
|
output_offset = [0]
|
||||||
|
|
||||||
def run_up(idx, sample, ended):
|
max_chunk_size = get_max_chunk_size(sample.device)
|
||||||
if idx >= len(self.up_blocks):
|
|
||||||
sample = self.conv_norm_out(sample)
|
|
||||||
if timestep_shift_scale is not None:
|
|
||||||
shift, scale = timestep_shift_scale
|
|
||||||
sample = sample * (1 + scale) + shift
|
|
||||||
sample = self.conv_act(sample)
|
|
||||||
if ended:
|
|
||||||
mark_conv3d_ended(self.conv_out)
|
|
||||||
sample = self.conv_out(sample, causal=self.causal)
|
|
||||||
if sample is not None and sample.shape[2] > 0:
|
|
||||||
output.append(sample)
|
|
||||||
return
|
|
||||||
|
|
||||||
up_block = self.up_blocks[idx]
|
self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
|
||||||
if (ended):
|
|
||||||
mark_conv3d_ended(up_block)
|
|
||||||
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
|
|
||||||
sample = checkpoint_fn(up_block)(
|
|
||||||
sample, causal=self.causal, timestep=scaled_timestep
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
|
|
||||||
|
|
||||||
if sample is None or sample.shape[2] == 0:
|
return output_buffer
|
||||||
return
|
|
||||||
|
|
||||||
total_bytes = sample.numel() * sample.element_size()
|
|
||||||
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
|
|
||||||
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
|
|
||||||
|
|
||||||
for chunk_idx, sample1 in enumerate(samples):
|
|
||||||
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
|
|
||||||
|
|
||||||
run_up(0, sample, True)
|
|
||||||
sample = torch.cat(output, dim=2)
|
|
||||||
|
|
||||||
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
|
|
||||||
|
|
||||||
return sample
|
|
||||||
|
|
||||||
def forward(self, *args, **kwargs):
|
def forward(self, *args, **kwargs):
|
||||||
try:
|
try:
|
||||||
@ -688,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
causal=True,
|
causal=True,
|
||||||
spatial_padding_mode=spatial_padding_mode,
|
spatial_padding_mode=spatial_padding_mode,
|
||||||
)
|
)
|
||||||
|
self.temporal_cache_state = {}
|
||||||
|
|
||||||
def forward(self, x, causal: bool = True):
|
def forward(self, x, causal: bool = True):
|
||||||
if self.stride[0] == 2:
|
tid = threading.get_ident()
|
||||||
|
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
|
||||||
|
if cached_input is not None:
|
||||||
|
x = torch_cat_if_needed([cached_input, x], dim=2)
|
||||||
|
cached_input = None
|
||||||
|
|
||||||
|
if self.stride[0] == 2 and pad_first:
|
||||||
x = torch.cat(
|
x = torch.cat(
|
||||||
[x[:, :, :1, :, :], x], dim=2
|
[x[:, :, :1, :, :], x], dim=2
|
||||||
) # duplicate first frames for padding
|
) # duplicate first frames for padding
|
||||||
|
pad_first = False
|
||||||
|
|
||||||
|
if x.shape[2] < self.stride[0]:
|
||||||
|
cached_input = x
|
||||||
|
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||||
|
return None
|
||||||
|
|
||||||
# skip connection
|
# skip connection
|
||||||
x_in = rearrange(
|
x_in = rearrange(
|
||||||
@ -708,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module):
|
|||||||
|
|
||||||
# conv
|
# conv
|
||||||
x = self.conv(x, causal=causal)
|
x = self.conv(x, causal=causal)
|
||||||
x = rearrange(
|
if self.stride[0] == 2 and x.shape[2] == 1:
|
||||||
x,
|
if cached_x is not None:
|
||||||
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
x = torch_cat_if_needed([cached_x, x], dim=2)
|
||||||
p1=self.stride[0],
|
cached_x = None
|
||||||
p2=self.stride[1],
|
else:
|
||||||
p3=self.stride[2],
|
cached_x = x
|
||||||
)
|
x = None
|
||||||
|
|
||||||
x = x + x_in
|
if x is not None:
|
||||||
|
x = rearrange(
|
||||||
|
x,
|
||||||
|
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
|
||||||
|
p1=self.stride[0],
|
||||||
|
p2=self.stride[1],
|
||||||
|
p3=self.stride[2],
|
||||||
|
)
|
||||||
|
|
||||||
|
cached = add_exchange_cache(x, cached, x_in, dim=2)
|
||||||
|
|
||||||
|
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
|
||||||
|
|
||||||
return x
|
return x
|
||||||
|
|
||||||
@ -1049,6 +1150,8 @@ class processor(nn.Module):
|
|||||||
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
|
||||||
|
|
||||||
class VideoVAE(nn.Module):
|
class VideoVAE(nn.Module):
|
||||||
|
comfy_has_chunked_io = True
|
||||||
|
|
||||||
def __init__(self, version=0, config=None):
|
def __init__(self, version=0, config=None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
@ -1191,14 +1294,15 @@ class VideoVAE(nn.Module):
|
|||||||
}
|
}
|
||||||
return config
|
return config
|
||||||
|
|
||||||
def encode(self, x):
|
def encode(self, x, device=None):
|
||||||
frames_count = x.shape[2]
|
x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
|
||||||
if ((frames_count - 1) % 8) != 0:
|
means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
|
||||||
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
|
|
||||||
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
|
|
||||||
return self.per_channel_statistics.normalize(means)
|
return self.per_channel_statistics.normalize(means)
|
||||||
|
|
||||||
def decode(self, x):
|
def decode_output_shape(self, input_shape):
|
||||||
|
return self.decoder.decode_output_shape(input_shape)
|
||||||
|
|
||||||
|
def decode(self, x, output_buffer=None):
|
||||||
if self.timestep_conditioning: #TODO: seed
|
if self.timestep_conditioning: #TODO: seed
|
||||||
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
|
||||||
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep)
|
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)
|
||||||
|
|||||||
@ -99,7 +99,7 @@ class Resample(nn.Module):
|
|||||||
else:
|
else:
|
||||||
self.resample = nn.Identity()
|
self.resample = nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
b, c, t, h, w = x.size()
|
b, c, t, h, w = x.size()
|
||||||
if self.mode == 'upsample3d':
|
if self.mode == 'upsample3d':
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
@ -109,22 +109,7 @@ class Resample(nn.Module):
|
|||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] != 'Rep':
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[
|
|
||||||
idx] is not None and feat_cache[idx] == 'Rep':
|
|
||||||
cache_x = torch.cat([
|
|
||||||
torch.zeros_like(cache_x).to(cache_x.device),
|
|
||||||
cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
if feat_cache[idx] == 'Rep':
|
if feat_cache[idx] == 'Rep':
|
||||||
x = self.time_conv(x)
|
x = self.time_conv(x)
|
||||||
else:
|
else:
|
||||||
@ -145,19 +130,24 @@ class Resample(nn.Module):
|
|||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
if feat_cache[idx] is None:
|
if feat_cache[idx] is None:
|
||||||
feat_cache[idx] = x.clone()
|
feat_cache[idx] = x
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
else:
|
||||||
|
|
||||||
cache_x = x[:, :, -1:, :, :].clone()
|
cache_x = x[:, :, -1:, :, :]
|
||||||
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
|
|
||||||
# # cache last frame of last two chunk
|
|
||||||
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
|
|
||||||
|
|
||||||
x = self.time_conv(
|
x = self.time_conv(
|
||||||
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
|
||||||
|
deferred_x = feat_cache[idx + 1]
|
||||||
|
if deferred_x is not None:
|
||||||
|
x = torch.cat([deferred_x, x], 2)
|
||||||
|
feat_cache[idx + 1] = None
|
||||||
|
|
||||||
|
if x.shape[2] == 1 and not final:
|
||||||
|
feat_cache[idx + 1] = x
|
||||||
|
x = None
|
||||||
|
|
||||||
|
feat_idx[0] += 2
|
||||||
return x
|
return x
|
||||||
|
|
||||||
|
|
||||||
@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
|
|||||||
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
|
||||||
if in_dim != out_dim else nn.Identity()
|
if in_dim != out_dim else nn.Identity()
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
old_x = x
|
old_x = x
|
||||||
for layer in self.residual:
|
for layer in self.residual:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
x = layer(x, cache_list=feat_cache, cache_idx=idx)
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
|
|||||||
self.proj = ops.Conv2d(dim, dim, 1)
|
self.proj = ops.Conv2d(dim, dim, 1)
|
||||||
self.optimized_attention = vae_attention()
|
self.optimized_attention = vae_attention()
|
||||||
|
|
||||||
def forward(self, x):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
identity = x
|
identity = x
|
||||||
b, c, t, h, w = x.size()
|
b, c, t, h, w = x.size()
|
||||||
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
x = rearrange(x, 'b c t h w -> (b t) c h w')
|
||||||
@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
|
|||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
CausalConv3d(out_dim, z_dim, 3, padding=1))
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
|
|||||||
## downsamples
|
## downsamples
|
||||||
for layer in self.downsamples:
|
for layer in self.downsamples:
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx, final=final)
|
||||||
|
if x is None:
|
||||||
|
return None
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx, final=final)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
|
|||||||
for layer in self.head:
|
for layer in self.head:
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
x = layer(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -389,18 +360,48 @@ class Decoder3d(nn.Module):
|
|||||||
RMS_norm(out_dim, images=False), nn.SiLU(),
|
RMS_norm(out_dim, images=False), nn.SiLU(),
|
||||||
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
CausalConv3d(out_dim, output_channels, 3, padding=1))
|
||||||
|
|
||||||
|
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
|
||||||
|
x = x_ref[0]
|
||||||
|
x_ref[0] = None
|
||||||
|
if layer_idx >= len(self.upsamples):
|
||||||
|
for layer in self.head:
|
||||||
|
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
||||||
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
|
x = layer(x, feat_cache[feat_idx[0]])
|
||||||
|
feat_cache[feat_idx[0]] = cache_x
|
||||||
|
feat_idx[0] += 1
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
out_chunks.append(x)
|
||||||
|
return
|
||||||
|
|
||||||
|
layer = self.upsamples[layer_idx]
|
||||||
|
if feat_cache is not None:
|
||||||
|
x = layer(x, feat_cache, feat_idx)
|
||||||
|
else:
|
||||||
|
x = layer(x)
|
||||||
|
|
||||||
|
if isinstance(layer, Resample) and layer.mode == 'upsample3d' and x.shape[2] > 2:
|
||||||
|
for frame_idx in range(0, x.shape[2], 2):
|
||||||
|
self.run_up(
|
||||||
|
layer_idx + 1,
|
||||||
|
[x[:, :, frame_idx:frame_idx + 2, :, :]],
|
||||||
|
feat_cache,
|
||||||
|
feat_idx.copy(),
|
||||||
|
out_chunks,
|
||||||
|
)
|
||||||
|
del x
|
||||||
|
return
|
||||||
|
|
||||||
|
next_x_ref = [x]
|
||||||
|
del x
|
||||||
|
self.run_up(layer_idx + 1, next_x_ref, feat_cache, feat_idx, out_chunks)
|
||||||
|
|
||||||
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
||||||
## conv1
|
## conv1
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
idx = feat_idx[0]
|
idx = feat_idx[0]
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
cache_x = x[:, :, -CACHE_T:, :, :]
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = self.conv1(x, feat_cache[idx])
|
x = self.conv1(x, feat_cache[idx])
|
||||||
feat_cache[idx] = cache_x
|
feat_cache[idx] = cache_x
|
||||||
feat_idx[0] += 1
|
feat_idx[0] += 1
|
||||||
@ -409,42 +410,21 @@ class Decoder3d(nn.Module):
|
|||||||
|
|
||||||
## middle
|
## middle
|
||||||
for layer in self.middle:
|
for layer in self.middle:
|
||||||
if isinstance(layer, ResidualBlock) and feat_cache is not None:
|
|
||||||
x = layer(x, feat_cache, feat_idx)
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
|
|
||||||
## upsamples
|
|
||||||
for layer in self.upsamples:
|
|
||||||
if feat_cache is not None:
|
if feat_cache is not None:
|
||||||
x = layer(x, feat_cache, feat_idx)
|
x = layer(x, feat_cache, feat_idx)
|
||||||
else:
|
else:
|
||||||
x = layer(x)
|
x = layer(x)
|
||||||
|
|
||||||
## head
|
out_chunks = []
|
||||||
for layer in self.head:
|
|
||||||
if isinstance(layer, CausalConv3d) and feat_cache is not None:
|
self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
|
||||||
idx = feat_idx[0]
|
return out_chunks
|
||||||
cache_x = x[:, :, -CACHE_T:, :, :].clone()
|
|
||||||
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
|
|
||||||
# cache last frame of last two chunk
|
|
||||||
cache_x = torch.cat([
|
|
||||||
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
|
|
||||||
cache_x.device), cache_x
|
|
||||||
],
|
|
||||||
dim=2)
|
|
||||||
x = layer(x, feat_cache[idx])
|
|
||||||
feat_cache[idx] = cache_x
|
|
||||||
feat_idx[0] += 1
|
|
||||||
else:
|
|
||||||
x = layer(x)
|
|
||||||
return x
|
|
||||||
|
|
||||||
|
|
||||||
def count_conv3d(model):
|
def count_cache_layers(model):
|
||||||
count = 0
|
count = 0
|
||||||
for m in model.modules():
|
for m in model.modules():
|
||||||
if isinstance(m, CausalConv3d):
|
if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
|
||||||
count += 1
|
count += 1
|
||||||
return count
|
return count
|
||||||
|
|
||||||
@ -482,11 +462,12 @@ class WanVAE(nn.Module):
|
|||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
## cache
|
## cache
|
||||||
t = x.shape[2]
|
t = x.shape[2]
|
||||||
iter_ = 1 + (t - 1) // 4
|
t = 1 + ((t - 1) // 4) * 4
|
||||||
|
iter_ = 1 + (t - 1) // 2
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.encoder)
|
feat_map = [None] * count_cache_layers(self.encoder)
|
||||||
## 对encode输入的x,按时间拆分为1、4、4、4....
|
## 对encode输入的x,按时间拆分为1、2、2、2....(总帧数先按4N+1向下取整)
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
if i == 0:
|
if i == 0:
|
||||||
@ -496,20 +477,23 @@ class WanVAE(nn.Module):
|
|||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.encoder(
|
out_ = self.encoder(
|
||||||
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :],
|
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||||
feat_cache=feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx,
|
||||||
|
final=(i == (iter_ - 1)))
|
||||||
|
if out_ is None:
|
||||||
|
continue
|
||||||
out = torch.cat([out, out_], 2)
|
out = torch.cat([out, out_], 2)
|
||||||
|
|
||||||
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
mu, log_var = self.conv1(out).chunk(2, dim=1)
|
||||||
return mu
|
return mu
|
||||||
|
|
||||||
def decode(self, z):
|
def decode(self, z):
|
||||||
conv_idx = [0]
|
|
||||||
# z: [b,c,t,h,w]
|
# z: [b,c,t,h,w]
|
||||||
iter_ = z.shape[2]
|
iter_ = 1 + z.shape[2] // 2
|
||||||
feat_map = None
|
feat_map = None
|
||||||
if iter_ > 1:
|
if iter_ > 1:
|
||||||
feat_map = [None] * count_conv3d(self.decoder)
|
feat_map = [None] * count_cache_layers(self.decoder)
|
||||||
x = self.conv2(z)
|
x = self.conv2(z)
|
||||||
for i in range(iter_):
|
for i in range(iter_):
|
||||||
conv_idx = [0]
|
conv_idx = [0]
|
||||||
@ -520,8 +504,8 @@ class WanVAE(nn.Module):
|
|||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
else:
|
else:
|
||||||
out_ = self.decoder(
|
out_ = self.decoder(
|
||||||
x[:, :, i:i + 1, :, :],
|
x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
|
||||||
feat_cache=feat_map,
|
feat_cache=feat_map,
|
||||||
feat_idx=conv_idx)
|
feat_idx=conv_idx)
|
||||||
out = torch.cat([out, out_], 2)
|
out += out_
|
||||||
return out
|
return torch.cat(out, 2)
|
||||||
|
|||||||
@ -1,9 +1,71 @@
|
|||||||
import math
|
import math
|
||||||
|
import ctypes
|
||||||
|
import threading
|
||||||
|
import dataclasses
|
||||||
import torch
|
import torch
|
||||||
from typing import NamedTuple
|
from typing import NamedTuple
|
||||||
|
|
||||||
from comfy.quant_ops import QuantizedTensor
|
from comfy.quant_ops import QuantizedTensor
|
||||||
|
|
||||||
|
|
||||||
|
class TensorFileSlice(NamedTuple):
|
||||||
|
file_ref: object
|
||||||
|
thread_id: int
|
||||||
|
offset: int
|
||||||
|
size: int
|
||||||
|
|
||||||
|
|
||||||
|
def read_tensor_file_slice_into(tensor, destination):
|
||||||
|
|
||||||
|
if isinstance(tensor, QuantizedTensor):
|
||||||
|
if not isinstance(destination, QuantizedTensor):
|
||||||
|
return False
|
||||||
|
if tensor._layout_cls != destination._layout_cls:
|
||||||
|
return False
|
||||||
|
|
||||||
|
if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
|
||||||
|
return False
|
||||||
|
|
||||||
|
dst_orig_dtype = destination._params.orig_dtype
|
||||||
|
destination._params.copy_from(tensor._params, non_blocking=False)
|
||||||
|
destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
|
||||||
|
return True
|
||||||
|
|
||||||
|
info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
|
||||||
|
if info is None:
|
||||||
|
return False
|
||||||
|
|
||||||
|
file_obj = info.file_ref
|
||||||
|
if (destination.device.type != "cpu"
|
||||||
|
or file_obj is None
|
||||||
|
or threading.get_ident() != info.thread_id
|
||||||
|
or destination.numel() * destination.element_size() < info.size
|
||||||
|
or tensor.numel() * tensor.element_size() != info.size
|
||||||
|
or tensor.storage_offset() != 0
|
||||||
|
or not tensor.is_contiguous()):
|
||||||
|
return False
|
||||||
|
|
||||||
|
if info.size == 0:
|
||||||
|
return True
|
||||||
|
|
||||||
|
buf_type = ctypes.c_ubyte * info.size
|
||||||
|
view = memoryview(buf_type.from_address(destination.data_ptr()))
|
||||||
|
|
||||||
|
try:
|
||||||
|
file_obj.seek(info.offset)
|
||||||
|
done = 0
|
||||||
|
while done < info.size:
|
||||||
|
try:
|
||||||
|
n = file_obj.readinto(view[done:])
|
||||||
|
except OSError:
|
||||||
|
return False
|
||||||
|
if n <= 0:
|
||||||
|
return False
|
||||||
|
done += n
|
||||||
|
return True
|
||||||
|
finally:
|
||||||
|
view.release()
|
||||||
|
|
||||||
class TensorGeometry(NamedTuple):
|
class TensorGeometry(NamedTuple):
|
||||||
shape: any
|
shape: any
|
||||||
dtype: torch.dtype
|
dtype: torch.dtype
|
||||||
|
|||||||
@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
|
|||||||
import torch
|
import torch
|
||||||
import logging
|
import logging
|
||||||
import comfy.ldm.lightricks.av_model
|
import comfy.ldm.lightricks.av_model
|
||||||
|
import comfy.context_windows
|
||||||
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
|
||||||
from comfy.ldm.cascade.stage_c import StageC
|
from comfy.ldm.cascade.stage_c import StageC
|
||||||
from comfy.ldm.cascade.stage_b import StageB
|
from comfy.ldm.cascade.stage_b import StageB
|
||||||
@ -285,6 +286,12 @@ class BaseModel(torch.nn.Module):
|
|||||||
return data
|
return data
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
"""Override in subclasses to handle model-specific cond slicing for context windows.
|
||||||
|
Return a sliced cond object, or None to fall through to default handling.
|
||||||
|
Use comfy.context_windows.slice_cond() for common cases."""
|
||||||
|
return None
|
||||||
|
|
||||||
def extra_conds(self, **kwargs):
|
def extra_conds(self, **kwargs):
|
||||||
out = {}
|
out = {}
|
||||||
concat_cond = self.concat_cond(**kwargs)
|
concat_cond = self.concat_cond(**kwargs)
|
||||||
@ -1375,6 +1382,11 @@ class WAN21_Vace(WAN21):
|
|||||||
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
if cond_key == "vace_context":
|
||||||
|
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
|
||||||
|
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||||
|
|
||||||
class WAN21_Camera(WAN21):
|
class WAN21_Camera(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
|
||||||
@ -1427,6 +1439,11 @@ class WAN21_HuMo(WAN21):
|
|||||||
|
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
if cond_key == "audio_embed":
|
||||||
|
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||||
|
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||||
|
|
||||||
class WAN22_Animate(WAN21):
|
class WAN22_Animate(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
|
||||||
@ -1444,6 +1461,13 @@ class WAN22_Animate(WAN21):
|
|||||||
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
if cond_key == "face_pixel_values":
|
||||||
|
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_scale=4, temporal_offset=1)
|
||||||
|
if cond_key == "pose_latents":
|
||||||
|
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=1)
|
||||||
|
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||||
|
|
||||||
class WAN22_S2V(WAN21):
|
class WAN22_S2V(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
|
||||||
@ -1480,6 +1504,11 @@ class WAN22_S2V(WAN21):
|
|||||||
out['reference_motion'] = reference_motion.shape
|
out['reference_motion'] = reference_motion.shape
|
||||||
return out
|
return out
|
||||||
|
|
||||||
|
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||||
|
if cond_key == "audio_embed":
|
||||||
|
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
|
||||||
|
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||||
|
|
||||||
class WAN22(WAN21):
|
class WAN22(WAN21):
|
||||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)
|
||||||
|
|||||||
@ -400,7 +400,7 @@ try:
|
|||||||
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
|
||||||
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
|
||||||
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
|
||||||
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
|
||||||
ENABLE_PYTORCH_ATTENTION = True
|
ENABLE_PYTORCH_ATTENTION = True
|
||||||
if rocm_version >= (7, 0):
|
if rocm_version >= (7, 0):
|
||||||
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
if any((a in arch) for a in ["gfx1200", "gfx1201"]):
|
||||||
@ -505,6 +505,28 @@ def module_size(module):
|
|||||||
module_mem += t.nbytes
|
module_mem += t.nbytes
|
||||||
return module_mem
|
return module_mem
|
||||||
|
|
||||||
|
def module_mmap_residency(module, free=False):
|
||||||
|
mmap_touched_mem = 0
|
||||||
|
module_mem = 0
|
||||||
|
bounced_mmaps = set()
|
||||||
|
sd = module.state_dict()
|
||||||
|
for k in sd:
|
||||||
|
t = sd[k]
|
||||||
|
module_mem += t.nbytes
|
||||||
|
storage = t._qdata.untyped_storage() if isinstance(t, comfy.quant_ops.QuantizedTensor) else t.untyped_storage()
|
||||||
|
if not getattr(storage, "_comfy_tensor_mmap_touched", False):
|
||||||
|
continue
|
||||||
|
mmap_touched_mem += t.nbytes
|
||||||
|
if not free:
|
||||||
|
continue
|
||||||
|
storage._comfy_tensor_mmap_touched = False
|
||||||
|
mmap_obj = storage._comfy_tensor_mmap_refs[0]
|
||||||
|
if mmap_obj in bounced_mmaps:
|
||||||
|
continue
|
||||||
|
mmap_obj.bounce()
|
||||||
|
bounced_mmaps.add(mmap_obj)
|
||||||
|
return mmap_touched_mem, module_mem
|
||||||
|
|
||||||
class LoadedModel:
|
class LoadedModel:
|
||||||
def __init__(self, model):
|
def __init__(self, model):
|
||||||
self._set_model(model)
|
self._set_model(model)
|
||||||
@ -519,6 +541,7 @@ class LoadedModel:
|
|||||||
if model.parent is not None:
|
if model.parent is not None:
|
||||||
self._parent_model = weakref.ref(model.parent)
|
self._parent_model = weakref.ref(model.parent)
|
||||||
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
|
||||||
|
self._patcher_finalizer.atexit = False
|
||||||
|
|
||||||
def _switch_parent(self):
|
def _switch_parent(self):
|
||||||
model = self._parent_model()
|
model = self._parent_model()
|
||||||
@ -532,6 +555,9 @@ class LoadedModel:
|
|||||||
def model_memory(self):
|
def model_memory(self):
|
||||||
return self.model.model_size()
|
return self.model.model_size()
|
||||||
|
|
||||||
|
def model_mmap_residency(self, free=False):
|
||||||
|
return self.model.model_mmap_residency(free=free)
|
||||||
|
|
||||||
def model_loaded_memory(self):
|
def model_loaded_memory(self):
|
||||||
return self.model.loaded_size()
|
return self.model.loaded_size()
|
||||||
|
|
||||||
@ -562,6 +588,7 @@ class LoadedModel:
|
|||||||
|
|
||||||
self.real_model = weakref.ref(real_model)
|
self.real_model = weakref.ref(real_model)
|
||||||
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
self.model_finalizer = weakref.finalize(real_model, cleanup_models)
|
||||||
|
self.model_finalizer.atexit = False
|
||||||
return real_model
|
return real_model
|
||||||
|
|
||||||
def should_reload_model(self, force_patch_weights=False):
|
def should_reload_model(self, force_patch_weights=False):
|
||||||
@ -633,7 +660,7 @@ def extra_reserved_memory():
|
|||||||
def minimum_inference_memory():
|
def minimum_inference_memory():
|
||||||
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
|
||||||
|
|
||||||
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0):
|
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
|
||||||
cleanup_models_gc()
|
cleanup_models_gc()
|
||||||
unloaded_model = []
|
unloaded_model = []
|
||||||
can_unload = []
|
can_unload = []
|
||||||
@ -646,13 +673,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
|
||||||
shift_model.currently_used = False
|
shift_model.currently_used = False
|
||||||
|
|
||||||
for x in sorted(can_unload):
|
can_unload_sorted = sorted(can_unload)
|
||||||
|
for x in can_unload_sorted:
|
||||||
i = x[-1]
|
i = x[-1]
|
||||||
memory_to_free = 1e32
|
memory_to_free = 1e32
|
||||||
ram_to_free = 1e32
|
pins_to_free = 1e32
|
||||||
if not DISABLE_SMART_MEMORY:
|
if not DISABLE_SMART_MEMORY:
|
||||||
memory_to_free = memory_required - get_free_memory(device)
|
memory_to_free = memory_required - get_free_memory(device)
|
||||||
ram_to_free = ram_required - get_free_ram()
|
pins_to_free = pins_required - get_free_ram()
|
||||||
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
if current_loaded_models[i].model.is_dynamic() and for_dynamic:
|
||||||
#don't actually unload dynamic models for the sake of other dynamic models
|
#don't actually unload dynamic models for the sake of other dynamic models
|
||||||
#as that works on-demand.
|
#as that works on-demand.
|
||||||
@ -661,9 +689,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
|
|||||||
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
|
||||||
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
unloaded_model.append(i)
|
unloaded_model.append(i)
|
||||||
if ram_to_free > 0:
|
if pins_to_free > 0:
|
||||||
|
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
|
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
|
||||||
|
|
||||||
|
for x in can_unload_sorted:
|
||||||
|
i = x[-1]
|
||||||
|
ram_to_free = ram_required - psutil.virtual_memory().available
|
||||||
|
if ram_to_free <= 0 and i not in unloaded_model:
|
||||||
|
continue
|
||||||
|
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
|
||||||
|
if resident_memory > 0:
|
||||||
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
|
||||||
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
|
|
||||||
|
|
||||||
for i in sorted(unloaded_model, reverse=True):
|
for i in sorted(unloaded_model, reverse=True):
|
||||||
unloaded_models.append(current_loaded_models.pop(i))
|
unloaded_models.append(current_loaded_models.pop(i))
|
||||||
@ -729,17 +766,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
|
|||||||
|
|
||||||
|
|
||||||
total_memory_required = {}
|
total_memory_required = {}
|
||||||
|
total_pins_required = {}
|
||||||
total_ram_required = {}
|
total_ram_required = {}
|
||||||
for loaded_model in models_to_load:
|
for loaded_model in models_to_load:
|
||||||
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device)
|
device = loaded_model.device
|
||||||
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we
|
total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
|
||||||
#want to do.
|
resident_memory, model_memory = loaded_model.model_mmap_residency()
|
||||||
#FIXME: This should subtract off the to_load current pin consumption.
|
pinned_memory = loaded_model.model.pinned_memory_size()
|
||||||
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2
|
#FIXME: This can over-free the pins as it budgets to pin the entire model. We should
|
||||||
|
#make this JIT to keep as much pinned as possible.
|
||||||
|
pins_required = model_memory - pinned_memory
|
||||||
|
ram_required = model_memory - resident_memory
|
||||||
|
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
|
||||||
|
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device])
|
free_memory(total_memory_required[device] * 1.1 + extra_mem,
|
||||||
|
device,
|
||||||
|
for_dynamic=free_for_dynamic,
|
||||||
|
pins_required=total_pins_required[device],
|
||||||
|
ram_required=total_ram_required[device])
|
||||||
|
|
||||||
for device in total_memory_required:
|
for device in total_memory_required:
|
||||||
if device != torch.device("cpu"):
|
if device != torch.device("cpu"):
|
||||||
@ -1005,6 +1052,12 @@ def intermediate_device():
|
|||||||
else:
|
else:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
|
|
||||||
|
def intermediate_dtype():
|
||||||
|
if args.fp16_intermediates:
|
||||||
|
return torch.float16
|
||||||
|
else:
|
||||||
|
return torch.float32
|
||||||
|
|
||||||
def vae_device():
|
def vae_device():
|
||||||
if args.cpu_vae:
|
if args.cpu_vae:
|
||||||
return torch.device("cpu")
|
return torch.device("cpu")
|
||||||
@ -1225,6 +1278,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
|
|||||||
dest_view = dest_views.pop(0)
|
dest_view = dest_views.pop(0)
|
||||||
if tensor is None:
|
if tensor is None:
|
||||||
continue
|
continue
|
||||||
|
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
|
||||||
|
continue
|
||||||
|
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
|
||||||
|
if hasattr(storage, "_comfy_tensor_mmap_touched"):
|
||||||
|
storage._comfy_tensor_mmap_touched = True
|
||||||
dest_view.copy_(tensor, non_blocking=non_blocking)
|
dest_view.copy_(tensor, non_blocking=non_blocking)
|
||||||
|
|
||||||
|
|
||||||
@ -1662,6 +1720,19 @@ def supports_nvfp4_compute(device=None):
|
|||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
def supports_mxfp8_compute(device=None):
|
||||||
|
if not is_nvidia():
|
||||||
|
return False
|
||||||
|
|
||||||
|
if torch_version_numeric < (2, 10):
|
||||||
|
return False
|
||||||
|
|
||||||
|
props = torch.cuda.get_device_properties(device)
|
||||||
|
if props.major < 10:
|
||||||
|
return False
|
||||||
|
|
||||||
|
return True
|
||||||
|
|
||||||
def extended_fp16_support():
|
def extended_fp16_support():
|
||||||
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
# TODO: check why some models work with fp16 on newer torch versions but not on older
|
||||||
if torch_version_numeric < (2, 7):
|
if torch_version_numeric < (2, 7):
|
||||||
|
|||||||
@ -297,6 +297,9 @@ class ModelPatcher:
|
|||||||
self.size = comfy.model_management.module_size(self.model)
|
self.size = comfy.model_management.module_size(self.model)
|
||||||
return self.size
|
return self.size
|
||||||
|
|
||||||
|
def model_mmap_residency(self, free=False):
|
||||||
|
return comfy.model_management.module_mmap_residency(self.model, free=free)
|
||||||
|
|
||||||
def get_ram_usage(self):
|
def get_ram_usage(self):
|
||||||
return self.model_size()
|
return self.model_size()
|
||||||
|
|
||||||
@ -1063,6 +1066,10 @@ class ModelPatcher:
|
|||||||
|
|
||||||
return self.model.model_loaded_weight_memory - current_used
|
return self.model.model_loaded_weight_memory - current_used
|
||||||
|
|
||||||
|
def pinned_memory_size(self):
|
||||||
|
# Pinned memory pressure tracking is only implemented for DynamicVram loading
|
||||||
|
return 0
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
|
|||||||
|
|
||||||
return freed
|
return freed
|
||||||
|
|
||||||
|
def pinned_memory_size(self):
|
||||||
|
total = 0
|
||||||
|
loading = self._load_list(for_dynamic=True)
|
||||||
|
for x in loading:
|
||||||
|
_, _, _, _, m, _ = x
|
||||||
|
pin = comfy.pinned_memory.get_pin(m)
|
||||||
|
if pin is not None:
|
||||||
|
total += pin.numel() * pin.element_size()
|
||||||
|
return total
|
||||||
|
|
||||||
def partially_unload_ram(self, ram_to_unload):
|
def partially_unload_ram(self, ram_to_unload):
|
||||||
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
|
||||||
for x in loading:
|
for x in loading:
|
||||||
|
|||||||
236
comfy/ops.py
236
comfy/ops.py
@ -306,10 +306,40 @@ class CastWeightBiasOp:
|
|||||||
bias_function = []
|
bias_function = []
|
||||||
|
|
||||||
class disable_weight_init:
|
class disable_weight_init:
|
||||||
|
@staticmethod
|
||||||
|
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
|
||||||
|
missing_keys, unexpected_keys, weight_shape,
|
||||||
|
bias_shape=None):
|
||||||
|
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
||||||
|
prefix_len = len(prefix)
|
||||||
|
for k, v in state_dict.items():
|
||||||
|
key = k[prefix_len:]
|
||||||
|
if key == "weight":
|
||||||
|
if not assign_to_params_buffers:
|
||||||
|
v = v.clone()
|
||||||
|
module.weight = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
elif bias_shape is not None and key == "bias" and v is not None:
|
||||||
|
if not assign_to_params_buffers:
|
||||||
|
v = v.clone()
|
||||||
|
module.bias = torch.nn.Parameter(v, requires_grad=False)
|
||||||
|
else:
|
||||||
|
unexpected_keys.append(k)
|
||||||
|
|
||||||
|
if module.weight is None:
|
||||||
|
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
|
||||||
|
missing_keys.append(prefix + "weight")
|
||||||
|
|
||||||
|
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
|
||||||
|
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
|
||||||
|
missing_keys.append(prefix + "bias")
|
||||||
|
|
||||||
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
class Linear(torch.nn.Linear, CastWeightBiasOp):
|
||||||
|
|
||||||
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
|
||||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
# don't trust subclasses that BYO state dict loader to call us.
|
||||||
|
if (not comfy.model_management.WINDOWS
|
||||||
|
or not comfy.memory_management.aimdo_enabled
|
||||||
|
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||||
super().__init__(in_features, out_features, bias, device, dtype)
|
super().__init__(in_features, out_features, bias, device, dtype)
|
||||||
return
|
return
|
||||||
|
|
||||||
@ -330,32 +360,21 @@ class disable_weight_init:
|
|||||||
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
strict, missing_keys, unexpected_keys, error_msgs):
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled:
|
if (not comfy.model_management.WINDOWS
|
||||||
|
or not comfy.memory_management.aimdo_enabled
|
||||||
|
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
|
||||||
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
missing_keys, unexpected_keys, error_msgs)
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
|
disable_weight_init._lazy_load_from_state_dict(
|
||||||
prefix_len = len(prefix)
|
self,
|
||||||
for k,v in state_dict.items():
|
state_dict,
|
||||||
if k[prefix_len:] == "weight":
|
prefix,
|
||||||
if not assign_to_params_buffers:
|
local_metadata,
|
||||||
v = v.clone()
|
missing_keys,
|
||||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
unexpected_keys,
|
||||||
elif k[prefix_len:] == "bias" and v is not None:
|
weight_shape=(self.in_features, self.out_features),
|
||||||
if not assign_to_params_buffers:
|
bias_shape=(self.out_features,),
|
||||||
v = v.clone()
|
)
|
||||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
else:
|
|
||||||
unexpected_keys.append(k)
|
|
||||||
|
|
||||||
#Reconcile default construction of the weight if its missing.
|
|
||||||
if self.weight is None:
|
|
||||||
v = torch.zeros(self.in_features, self.out_features)
|
|
||||||
self.weight = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
missing_keys.append(prefix+"weight")
|
|
||||||
if self.bias is None and self.comfy_need_lazy_init_bias:
|
|
||||||
v = torch.zeros(self.out_features,)
|
|
||||||
self.bias = torch.nn.Parameter(v, requires_grad=False)
|
|
||||||
missing_keys.append(prefix+"bias")
|
|
||||||
|
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
@ -547,6 +566,53 @@ class disable_weight_init:
|
|||||||
return super().forward(*args, **kwargs)
|
return super().forward(*args, **kwargs)
|
||||||
|
|
||||||
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
class Embedding(torch.nn.Embedding, CastWeightBiasOp):
|
||||||
|
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
|
||||||
|
norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
|
||||||
|
_freeze=False, device=None, dtype=None):
|
||||||
|
# don't trust subclasses that BYO state dict loader to call us.
|
||||||
|
if (not comfy.model_management.WINDOWS
|
||||||
|
or not comfy.memory_management.aimdo_enabled
|
||||||
|
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||||
|
super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
|
||||||
|
norm_type, scale_grad_by_freq, sparse, _weight,
|
||||||
|
_freeze, device, dtype)
|
||||||
|
return
|
||||||
|
|
||||||
|
torch.nn.Module.__init__(self)
|
||||||
|
self.num_embeddings = num_embeddings
|
||||||
|
self.embedding_dim = embedding_dim
|
||||||
|
self.padding_idx = padding_idx
|
||||||
|
self.max_norm = max_norm
|
||||||
|
self.norm_type = norm_type
|
||||||
|
self.scale_grad_by_freq = scale_grad_by_freq
|
||||||
|
self.sparse = sparse
|
||||||
|
# Keep shape/dtype visible for module introspection without reserving storage.
|
||||||
|
embedding_dtype = dtype if dtype is not None else torch.get_default_dtype()
|
||||||
|
self.weight = torch.nn.Parameter(
|
||||||
|
torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype),
|
||||||
|
requires_grad=False,
|
||||||
|
)
|
||||||
|
self.bias = None
|
||||||
|
self.weight_comfy_model_dtype = dtype
|
||||||
|
|
||||||
|
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
|
||||||
|
strict, missing_keys, unexpected_keys, error_msgs):
|
||||||
|
|
||||||
|
if (not comfy.model_management.WINDOWS
|
||||||
|
or not comfy.memory_management.aimdo_enabled
|
||||||
|
or type(self)._load_from_state_dict is not disable_weight_init.Embedding._load_from_state_dict):
|
||||||
|
return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
|
||||||
|
missing_keys, unexpected_keys, error_msgs)
|
||||||
|
disable_weight_init._lazy_load_from_state_dict(
|
||||||
|
self,
|
||||||
|
state_dict,
|
||||||
|
prefix,
|
||||||
|
local_metadata,
|
||||||
|
missing_keys,
|
||||||
|
unexpected_keys,
|
||||||
|
weight_shape=(self.num_embeddings, self.embedding_dim),
|
||||||
|
)
|
||||||
|
|
||||||
def reset_parameters(self):
|
def reset_parameters(self):
|
||||||
self.bias = None
|
self.bias = None
|
||||||
return None
|
return None
|
||||||
@ -710,6 +776,71 @@ from .quant_ops import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class QuantLinearFunc(torch.autograd.Function):
|
||||||
|
"""Custom autograd function for quantized linear: quantized forward, compute_dtype backward.
|
||||||
|
Handles any input rank by flattening to 2D for matmul and restoring shape after.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
|
||||||
|
input_shape = input_float.shape
|
||||||
|
inp = input_float.detach().flatten(0, -2) # zero-cost view to 2D
|
||||||
|
|
||||||
|
# Quantize input (same as inference path)
|
||||||
|
if layout_type is not None:
|
||||||
|
q_input = QuantizedTensor.from_float(inp, layout_type, scale=input_scale)
|
||||||
|
else:
|
||||||
|
q_input = inp
|
||||||
|
|
||||||
|
w = weight.detach() if weight.requires_grad else weight
|
||||||
|
b = bias.detach() if bias is not None and bias.requires_grad else bias
|
||||||
|
|
||||||
|
output = torch.nn.functional.linear(q_input, w, b)
|
||||||
|
|
||||||
|
# Restore original input shape
|
||||||
|
if len(input_shape) > 2:
|
||||||
|
output = output.unflatten(0, input_shape[:-1])
|
||||||
|
|
||||||
|
ctx.save_for_backward(input_float, weight)
|
||||||
|
ctx.input_shape = input_shape
|
||||||
|
ctx.has_bias = bias is not None
|
||||||
|
ctx.compute_dtype = compute_dtype
|
||||||
|
ctx.weight_requires_grad = weight.requires_grad
|
||||||
|
|
||||||
|
return output
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
@torch.autograd.function.once_differentiable
|
||||||
|
def backward(ctx, grad_output):
|
||||||
|
input_float, weight = ctx.saved_tensors
|
||||||
|
compute_dtype = ctx.compute_dtype
|
||||||
|
grad_2d = grad_output.flatten(0, -2).to(compute_dtype)
|
||||||
|
|
||||||
|
# Dequantize weight to compute dtype for backward matmul
|
||||||
|
if isinstance(weight, QuantizedTensor):
|
||||||
|
weight_f = weight.dequantize().to(compute_dtype)
|
||||||
|
else:
|
||||||
|
weight_f = weight.to(compute_dtype)
|
||||||
|
|
||||||
|
# grad_input = grad_output @ weight
|
||||||
|
grad_input = torch.mm(grad_2d, weight_f)
|
||||||
|
if len(ctx.input_shape) > 2:
|
||||||
|
grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])
|
||||||
|
|
||||||
|
# grad_weight (only if weight requires grad, typically frozen for quantized training)
|
||||||
|
grad_weight = None
|
||||||
|
if ctx.weight_requires_grad:
|
||||||
|
input_f = input_float.flatten(0, -2).to(compute_dtype)
|
||||||
|
grad_weight = torch.mm(grad_2d.t(), input_f)
|
||||||
|
|
||||||
|
# grad_bias
|
||||||
|
grad_bias = None
|
||||||
|
if ctx.has_bias:
|
||||||
|
grad_bias = grad_2d.sum(dim=0)
|
||||||
|
|
||||||
|
return grad_input, grad_weight, grad_bias, None, None, None
|
||||||
|
|
||||||
|
|
||||||
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
|
||||||
class MixedPrecisionOps(manual_cast):
|
class MixedPrecisionOps(manual_cast):
|
||||||
_quant_config = quant_config
|
_quant_config = quant_config
|
||||||
@ -801,6 +932,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
orig_shape=(self.out_features, self.in_features),
|
orig_shape=(self.out_features, self.in_features),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
elif self.quant_format == "mxfp8":
|
||||||
|
# MXFP8: E8M0 block scales stored as uint8 in safetensors
|
||||||
|
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
|
||||||
|
dtype=torch.uint8)
|
||||||
|
|
||||||
|
if block_scale is None:
|
||||||
|
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
|
||||||
|
|
||||||
|
block_scale = block_scale.view(torch.float8_e8m0fnu)
|
||||||
|
|
||||||
|
params = layout_cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=MixedPrecisionOps._compute_dtype,
|
||||||
|
orig_shape=(self.out_features, self.in_features),
|
||||||
|
)
|
||||||
|
|
||||||
elif self.quant_format == "nvfp4":
|
elif self.quant_format == "nvfp4":
|
||||||
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
|
||||||
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
|
||||||
@ -888,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
#If cast needs to apply lora, it should be done in the compute dtype
|
#If cast needs to apply lora, it should be done in the compute dtype
|
||||||
compute_dtype = input.dtype
|
compute_dtype = input.dtype
|
||||||
|
|
||||||
if (getattr(self, 'layout_type', None) is not None and
|
_use_quantized = (
|
||||||
|
getattr(self, 'layout_type', None) is not None and
|
||||||
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
|
||||||
not getattr(self, 'comfy_force_cast_weights', False) and
|
not getattr(self, 'comfy_force_cast_weights', False) and
|
||||||
len(self.weight_function) == 0 and len(self.bias_function) == 0):
|
len(self.weight_function) == 0 and len(self.bias_function) == 0
|
||||||
|
)
|
||||||
|
|
||||||
|
# Training path: quantized forward with compute_dtype backward via autograd function
|
||||||
|
if (input.requires_grad and _use_quantized):
|
||||||
|
|
||||||
|
weight, bias, offload_stream = cast_bias_weight(
|
||||||
|
self,
|
||||||
|
input,
|
||||||
|
offloadable=True,
|
||||||
|
compute_dtype=compute_dtype,
|
||||||
|
want_requant=True
|
||||||
|
)
|
||||||
|
|
||||||
|
scale = getattr(self, 'input_scale', None)
|
||||||
|
if scale is not None:
|
||||||
|
scale = comfy.model_management.cast_to_device(scale, input.device, None)
|
||||||
|
|
||||||
|
output = QuantLinearFunc.apply(
|
||||||
|
input, weight, bias, self.layout_type, scale, compute_dtype
|
||||||
|
)
|
||||||
|
|
||||||
|
uncast_bias_weight(self, weight, bias, offload_stream)
|
||||||
|
return output
|
||||||
|
|
||||||
|
# Inference path (unchanged)
|
||||||
|
if _use_quantized:
|
||||||
|
|
||||||
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
|
||||||
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
|
||||||
@ -939,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
for key, param in self._parameters.items():
|
for key, param in self._parameters.items():
|
||||||
if param is None:
|
if param is None:
|
||||||
continue
|
continue
|
||||||
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False))
|
p = fn(param)
|
||||||
|
if p.is_inference():
|
||||||
|
p = p.clone()
|
||||||
|
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
|
||||||
for key, buf in self._buffers.items():
|
for key, buf in self._buffers.items():
|
||||||
if buf is not None:
|
if buf is not None:
|
||||||
self._buffers[key] = fn(buf)
|
self._buffers[key] = fn(buf)
|
||||||
@ -950,12 +1127,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
|
|||||||
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
|
||||||
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
|
||||||
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
|
||||||
|
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
|
||||||
|
|
||||||
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
|
||||||
logging.info("Using mixed precision operations")
|
logging.info("Using mixed precision operations")
|
||||||
disabled = set()
|
disabled = set()
|
||||||
if not nvfp4_compute:
|
if not nvfp4_compute:
|
||||||
disabled.add("nvfp4")
|
disabled.add("nvfp4")
|
||||||
|
if not mxfp8_compute:
|
||||||
|
disabled.add("mxfp8")
|
||||||
if not fp8_compute:
|
if not fp8_compute:
|
||||||
disabled.add("float8_e4m3fn")
|
disabled.add("float8_e4m3fn")
|
||||||
disabled.add("float8_e5m2")
|
disabled.add("float8_e5m2")
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
import torch
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
|
import comfy_aimdo.host_buffer
|
||||||
|
import comfy_aimdo.torch
|
||||||
|
|
||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
|
|
||||||
@ -12,18 +13,31 @@ def pin_memory(module):
|
|||||||
return
|
return
|
||||||
#FIXME: This is a RAM cache trigger event
|
#FIXME: This is a RAM cache trigger event
|
||||||
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
|
||||||
pin = torch.empty((size,), dtype=torch.uint8)
|
|
||||||
if comfy.model_management.pin_memory(pin):
|
if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
|
||||||
module._pin = pin
|
|
||||||
else:
|
|
||||||
module.pin_failed = True
|
module.pin_failed = True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
try:
|
||||||
|
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
|
||||||
|
except RuntimeError:
|
||||||
|
module.pin_failed = True
|
||||||
|
return False
|
||||||
|
|
||||||
|
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
|
||||||
|
module._pin_hostbuf = hostbuf
|
||||||
|
comfy.model_management.TOTAL_PINNED_MEMORY += size
|
||||||
return True
|
return True
|
||||||
|
|
||||||
def unpin_memory(module):
|
def unpin_memory(module):
|
||||||
if get_pin(module) is None:
|
if get_pin(module) is None:
|
||||||
return 0
|
return 0
|
||||||
size = module._pin.numel() * module._pin.element_size()
|
size = module._pin.numel() * module._pin.element_size()
|
||||||
comfy.model_management.unpin_memory(module._pin)
|
|
||||||
|
comfy.model_management.TOTAL_PINNED_MEMORY -= size
|
||||||
|
if comfy.model_management.TOTAL_PINNED_MEMORY < 0:
|
||||||
|
comfy.model_management.TOTAL_PINNED_MEMORY = 0
|
||||||
|
|
||||||
del module._pin
|
del module._pin
|
||||||
|
del module._pin_hostbuf
|
||||||
return size
|
return size
|
||||||
|
|||||||
@ -43,6 +43,18 @@ except ImportError as e:
|
|||||||
def get_layout_class(name):
|
def get_layout_class(name):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
_CK_MXFP8_AVAILABLE = False
|
||||||
|
if _CK_AVAILABLE:
|
||||||
|
try:
|
||||||
|
from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
|
||||||
|
_CK_MXFP8_AVAILABLE = True
|
||||||
|
except ImportError:
|
||||||
|
logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
|
||||||
|
|
||||||
|
if not _CK_MXFP8_AVAILABLE:
|
||||||
|
class _CKMxfp8Layout:
|
||||||
|
pass
|
||||||
|
|
||||||
import comfy.float
|
import comfy.float
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
|
|||||||
return qdata, params
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
|
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
|
||||||
|
@classmethod
|
||||||
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
|
if tensor.dim() != 2:
|
||||||
|
raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")
|
||||||
|
|
||||||
|
orig_dtype = tensor.dtype
|
||||||
|
orig_shape = tuple(tensor.shape)
|
||||||
|
|
||||||
|
padded_shape = cls.get_padded_shape(orig_shape)
|
||||||
|
needs_padding = padded_shape != orig_shape
|
||||||
|
|
||||||
|
if stochastic_rounding > 0:
|
||||||
|
qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(tensor, pad_32x=needs_padding, seed=stochastic_rounding)
|
||||||
|
else:
|
||||||
|
qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=needs_padding)
|
||||||
|
|
||||||
|
params = cls.Params(
|
||||||
|
scale=block_scale,
|
||||||
|
orig_dtype=orig_dtype,
|
||||||
|
orig_shape=orig_shape,
|
||||||
|
)
|
||||||
|
return qdata, params
|
||||||
|
|
||||||
|
|
||||||
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
class TensorCoreNVFP4Layout(_CKNvfp4Layout):
|
||||||
@classmethod
|
@classmethod
|
||||||
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
|
||||||
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
|
|||||||
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
|
||||||
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
|
||||||
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
|
||||||
|
|
||||||
QUANT_ALGOS = {
|
QUANT_ALGOS = {
|
||||||
"float8_e4m3fn": {
|
"float8_e4m3fn": {
|
||||||
@ -157,6 +196,14 @@ QUANT_ALGOS = {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if _CK_MXFP8_AVAILABLE:
|
||||||
|
QUANT_ALGOS["mxfp8"] = {
|
||||||
|
"storage_t": torch.float8_e4m3fn,
|
||||||
|
"parameters": {"weight_scale", "input_scale"},
|
||||||
|
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
|
||||||
|
"group_size": 32,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
# ==============================================================================
|
# ==============================================================================
|
||||||
# Re-exports for backward compatibility
|
# Re-exports for backward compatibility
|
||||||
|
|||||||
@ -8,12 +8,12 @@ import comfy.nested_tensor
|
|||||||
|
|
||||||
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
def prepare_noise_inner(latent_image, generator, noise_inds=None):
|
||||||
if noise_inds is None:
|
if noise_inds is None:
|
||||||
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||||
|
|
||||||
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
|
||||||
noises = []
|
noises = []
|
||||||
for i in range(unique_inds[-1]+1):
|
for i in range(unique_inds[-1]+1):
|
||||||
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu")
|
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
|
||||||
if i in unique_inds:
|
if i in unique_inds:
|
||||||
noises.append(noise)
|
noises.append(noise)
|
||||||
noises = [noises[i] for i in inverse]
|
noises = [noises[i] for i in inverse]
|
||||||
@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
|
|||||||
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
|
||||||
|
|
||||||
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return samples
|
return samples
|
||||||
|
|
||||||
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
|
||||||
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
samples = comfy.samplers.sample(model, noise, positive, negative, cfg, model.load_device, sampler, sigmas, model_options=model.model_options, latent_image=latent_image, denoise_mask=noise_mask, callback=callback, disable_pbar=disable_pbar, seed=seed)
|
||||||
samples = samples.to(comfy.model_management.intermediate_device())
|
samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return samples
|
return samples
|
||||||
|
|||||||
@ -985,8 +985,8 @@ class CFGGuider:
|
|||||||
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
|
||||||
device = self.model_patcher.load_device
|
device = self.model_patcher.load_device
|
||||||
|
|
||||||
noise = noise.to(device)
|
noise = noise.to(device=device, dtype=torch.float32)
|
||||||
latent_image = latent_image.to(device)
|
latent_image = latent_image.to(device=device, dtype=torch.float32)
|
||||||
sigmas = sigmas.to(device)
|
sigmas = sigmas.to(device)
|
||||||
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
|
||||||
|
|
||||||
@ -1028,6 +1028,7 @@ class CFGGuider:
|
|||||||
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
|
||||||
else:
|
else:
|
||||||
denoise_mask = denoise_masks[0]
|
denoise_mask = denoise_masks[0]
|
||||||
|
denoise_mask = denoise_mask.float()
|
||||||
|
|
||||||
self.conds = {}
|
self.conds = {}
|
||||||
for k in self.original_conds:
|
for k in self.original_conds:
|
||||||
|
|||||||
55
comfy/sd.py
55
comfy/sd.py
@ -455,7 +455,7 @@ class VAE:
|
|||||||
self.output_channels = 3
|
self.output_channels = 3
|
||||||
self.pad_channel_value = None
|
self.pad_channel_value = None
|
||||||
self.process_input = lambda image: image * 2.0 - 1.0
|
self.process_input = lambda image: image * 2.0 - 1.0
|
||||||
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0)
|
self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
|
||||||
self.working_dtypes = [torch.bfloat16, torch.float32]
|
self.working_dtypes = [torch.bfloat16, torch.float32]
|
||||||
self.disable_offload = False
|
self.disable_offload = False
|
||||||
self.not_video = False
|
self.not_video = False
|
||||||
@ -871,13 +871,16 @@ class VAE:
|
|||||||
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
|
||||||
return pixels
|
return pixels
|
||||||
|
|
||||||
|
def vae_output_dtype(self):
|
||||||
|
return model_management.intermediate_dtype()
|
||||||
|
|
||||||
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
|
||||||
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
|
||||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
|
||||||
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
output = self.process_output(
|
output = self.process_output(
|
||||||
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||||
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
|
||||||
@ -887,16 +890,16 @@ class VAE:
|
|||||||
|
|
||||||
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
|
||||||
if samples.ndim == 3:
|
if samples.ndim == 3:
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
else:
|
else:
|
||||||
og_shape = samples.shape
|
og_shape = samples.shape
|
||||||
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1])).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, output_device=self.output_device))
|
||||||
|
|
||||||
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
|
||||||
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float()
|
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
return self.process_output(comfy.utils.tiled_scale_multidim(samples, decode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.upscale_ratio, out_channels=self.output_channels, index_formulas=self.upscale_index_formula, output_device=self.output_device))
|
||||||
|
|
||||||
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
|
||||||
@ -905,7 +908,7 @@ class VAE:
|
|||||||
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
|
||||||
pbar = comfy.utils.ProgressBar(steps)
|
pbar = comfy.utils.ProgressBar(steps)
|
||||||
|
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
|
||||||
@ -914,7 +917,7 @@ class VAE:
|
|||||||
|
|
||||||
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
|
||||||
if self.latent_dim == 1:
|
if self.latent_dim == 1:
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
out_channels = self.latent_channels
|
out_channels = self.latent_channels
|
||||||
upscale_amount = 1 / self.downscale_ratio
|
upscale_amount = 1 / self.downscale_ratio
|
||||||
else:
|
else:
|
||||||
@ -923,7 +926,7 @@ class VAE:
|
|||||||
tile_x = tile_x // extra_channel_size
|
tile_x = tile_x // extra_channel_size
|
||||||
overlap = overlap // extra_channel_size
|
overlap = overlap // extra_channel_size
|
||||||
upscale_amount = 1 / self.downscale_ratio
|
upscale_amount = 1 / self.downscale_ratio
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
|
||||||
|
|
||||||
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
|
||||||
if self.latent_dim == 1:
|
if self.latent_dim == 1:
|
||||||
@ -932,7 +935,7 @@ class VAE:
|
|||||||
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
|
||||||
|
|
||||||
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
|
||||||
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float()
|
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
|
||||||
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
return comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_t, tile_x, tile_y), overlap=overlap, upscale_amount=self.downscale_ratio, out_channels=self.latent_channels, downscale=True, index_formulas=self.downscale_index_formula, output_device=self.output_device)
|
||||||
|
|
||||||
def decode(self, samples_in, vae_options={}):
|
def decode(self, samples_in, vae_options={}):
|
||||||
@ -948,12 +951,23 @@ class VAE:
|
|||||||
batch_number = int(free_memory / memory_used)
|
batch_number = int(free_memory / memory_used)
|
||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
|
|
||||||
|
# Pre-allocate output for VAEs that support direct buffer writes
|
||||||
|
preallocated = False
|
||||||
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
|
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
preallocated = True
|
||||||
|
|
||||||
for x in range(0, samples_in.shape[0], batch_number):
|
for x in range(0, samples_in.shape[0], batch_number):
|
||||||
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device)
|
samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
|
||||||
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float())
|
if preallocated:
|
||||||
if pixel_samples is None:
|
self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
|
||||||
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
else:
|
||||||
pixel_samples[x:x+batch_number] = out
|
out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
|
||||||
|
if pixel_samples is None:
|
||||||
|
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
|
pixel_samples[x:x+batch_number].copy_(out)
|
||||||
|
del out
|
||||||
|
self.process_output(pixel_samples[x:x+batch_number])
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
model_management.raise_non_oom(e)
|
model_management.raise_non_oom(e)
|
||||||
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
|
||||||
@ -964,6 +978,7 @@ class VAE:
|
|||||||
do_tile = True
|
do_tile = True
|
||||||
|
|
||||||
if do_tile:
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
dims = samples_in.ndim - 2
|
dims = samples_in.ndim - 2
|
||||||
if dims == 1 or self.extra_1d_channel is not None:
|
if dims == 1 or self.extra_1d_channel is not None:
|
||||||
pixel_samples = self.decode_tiled_1d(samples_in)
|
pixel_samples = self.decode_tiled_1d(samples_in)
|
||||||
@ -1024,10 +1039,15 @@ class VAE:
|
|||||||
batch_number = max(1, batch_number)
|
batch_number = max(1, batch_number)
|
||||||
samples = None
|
samples = None
|
||||||
for x in range(0, pixel_samples.shape[0], batch_number):
|
for x in range(0, pixel_samples.shape[0], batch_number):
|
||||||
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device)
|
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
|
||||||
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float()
|
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
|
||||||
|
out = self.first_stage_model.encode(pixels_in, device=self.device)
|
||||||
|
else:
|
||||||
|
pixels_in = pixels_in.to(self.device)
|
||||||
|
out = self.first_stage_model.encode(pixels_in)
|
||||||
|
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
|
||||||
if samples is None:
|
if samples is None:
|
||||||
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device)
|
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
|
||||||
samples[x:x + batch_number] = out
|
samples[x:x + batch_number] = out
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@ -1040,6 +1060,7 @@ class VAE:
|
|||||||
do_tile = True
|
do_tile = True
|
||||||
|
|
||||||
if do_tile:
|
if do_tile:
|
||||||
|
comfy.model_management.soft_empty_cache()
|
||||||
if self.latent_dim == 3:
|
if self.latent_dim == 3:
|
||||||
tile = 256
|
tile = 256
|
||||||
overlap = tile // 4
|
overlap = tile // 4
|
||||||
|
|||||||
@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
|
|||||||
out, pooled = o[:2]
|
out, pooled = o[:2]
|
||||||
|
|
||||||
if pooled is not None:
|
if pooled is not None:
|
||||||
first_pooled = pooled[0:1].to(model_management.intermediate_device())
|
first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
|
||||||
else:
|
else:
|
||||||
first_pooled = pooled
|
first_pooled = pooled
|
||||||
|
|
||||||
@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
|
|||||||
output.append(z)
|
output.append(z)
|
||||||
|
|
||||||
if (len(output) == 0):
|
if (len(output) == 0):
|
||||||
r = (out[-1:].to(model_management.intermediate_device()), first_pooled)
|
r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
|
||||||
else:
|
else:
|
||||||
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled)
|
r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
|
||||||
|
|
||||||
if len(o) > 2:
|
if len(o) > 2:
|
||||||
extra = {}
|
extra = {}
|
||||||
for k in o[2]:
|
for k in o[2]:
|
||||||
v = o[2][k]
|
v = o[2][k]
|
||||||
if k == "attention_mask":
|
if k == "attention_mask":
|
||||||
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device())
|
v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
|
||||||
extra[k] = v
|
extra[k] = v
|
||||||
|
|
||||||
r = r + (extra,)
|
r = r + (extra,)
|
||||||
|
|||||||
@ -20,6 +20,8 @@
|
|||||||
import torch
|
import torch
|
||||||
import math
|
import math
|
||||||
import struct
|
import struct
|
||||||
|
import ctypes
|
||||||
|
import os
|
||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import safetensors.torch
|
import safetensors.torch
|
||||||
import numpy as np
|
import numpy as np
|
||||||
@ -32,7 +34,7 @@ from einops import rearrange
|
|||||||
from comfy.cli_args import args
|
from comfy.cli_args import args
|
||||||
import json
|
import json
|
||||||
import time
|
import time
|
||||||
import mmap
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
|
||||||
MMAP_TORCH_FILES = args.mmap_torch_files
|
MMAP_TORCH_FILES = args.mmap_torch_files
|
||||||
@ -81,14 +83,17 @@ _TYPES = {
|
|||||||
}
|
}
|
||||||
|
|
||||||
def load_safetensors(ckpt):
|
def load_safetensors(ckpt):
|
||||||
f = open(ckpt, "rb")
|
import comfy_aimdo.model_mmap
|
||||||
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
|
|
||||||
mv = memoryview(mapping)
|
|
||||||
|
|
||||||
header_size = struct.unpack("<Q", mapping[:8])[0]
|
f = open(ckpt, "rb", buffering=0)
|
||||||
header = json.loads(mapping[8:8+header_size].decode("utf-8"))
|
model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
|
||||||
|
file_size = os.path.getsize(ckpt)
|
||||||
|
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
|
||||||
|
|
||||||
mv = mv[8 + header_size:]
|
header_size = struct.unpack("<Q", mv[:8])[0]
|
||||||
|
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
|
||||||
|
|
||||||
|
mv = mv[(data_base_offset := 8 + header_size):]
|
||||||
|
|
||||||
sd = {}
|
sd = {}
|
||||||
for name, info in header.items():
|
for name, info in header.items():
|
||||||
@ -102,7 +107,14 @@ def load_safetensors(ckpt):
|
|||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
#We are working with read-only RAM by design
|
#We are working with read-only RAM by design
|
||||||
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
warnings.filterwarnings("ignore", message="The given buffer is not writable")
|
||||||
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
|
||||||
|
storage = tensor.untyped_storage()
|
||||||
|
setattr(storage,
|
||||||
|
"_comfy_tensor_file_slice",
|
||||||
|
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
|
||||||
|
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
|
||||||
|
setattr(storage, "_comfy_tensor_mmap_touched", False)
|
||||||
|
sd[name] = tensor
|
||||||
|
|
||||||
return sd, header.get("__metadata__", {}),
|
return sd, header.get("__metadata__", {}),
|
||||||
|
|
||||||
@ -885,6 +897,10 @@ def set_attr(obj, attr, value):
|
|||||||
return prev
|
return prev
|
||||||
|
|
||||||
def set_attr_param(obj, attr, value):
|
def set_attr_param(obj, attr, value):
|
||||||
|
# Clone inference tensors (created under torch.inference_mode) since
|
||||||
|
# their version counter is frozen and nn.Parameter() cannot wrap them.
|
||||||
|
if (not torch.is_inference_mode_enabled()) and value.is_inference():
|
||||||
|
value = value.clone()
|
||||||
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
|
||||||
|
|
||||||
def set_attr_buffer(obj, attr, value):
|
def set_attr_buffer(obj, attr, value):
|
||||||
@ -1119,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
continue
|
continue
|
||||||
|
|
||||||
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
out = output[b:b+1].zero_()
|
||||||
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device)
|
out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
|
||||||
|
|
||||||
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
|
||||||
|
|
||||||
@ -1135,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
upscaled.append(round(get_pos(d, pos)))
|
upscaled.append(round(get_pos(d, pos)))
|
||||||
|
|
||||||
ps = function(s_in).to(output_device)
|
ps = function(s_in).to(output_device)
|
||||||
mask = torch.ones_like(ps)
|
mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
|
||||||
|
|
||||||
for d in range(2, dims + 2):
|
for d in range(2, dims + 2):
|
||||||
feather = round(get_scale(d - 2, overlap[d - 2]))
|
feather = round(get_scale(d - 2, overlap[d - 2]))
|
||||||
@ -1158,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
|
|||||||
if pbar is not None:
|
if pbar is not None:
|
||||||
pbar.update(1)
|
pbar.update(1)
|
||||||
|
|
||||||
output[b:b+1] = out/out_div
|
out.div_(out_div)
|
||||||
return output
|
return output
|
||||||
|
|
||||||
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):
|
||||||
|
|||||||
@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.node_replacement = self.NodeReplacement()
|
self.node_replacement = self.NodeReplacement()
|
||||||
self.execution = self.Execution()
|
self.execution = self.Execution()
|
||||||
|
self.caching = self.Caching()
|
||||||
|
|
||||||
class NodeReplacement(ProxiedSingleton):
|
class NodeReplacement(ProxiedSingleton):
|
||||||
async def register(self, node_replace: io.NodeReplace) -> None:
|
async def register(self, node_replace: io.NodeReplace) -> None:
|
||||||
@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
|
|||||||
image=to_display,
|
image=to_display,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
class Caching(ProxiedSingleton):
|
||||||
|
"""
|
||||||
|
External cache provider API for sharing cached node outputs
|
||||||
|
across ComfyUI instances.
|
||||||
|
|
||||||
|
Example::
|
||||||
|
|
||||||
|
from comfy_api.latest import Caching
|
||||||
|
|
||||||
|
class MyCacheProvider(Caching.CacheProvider):
|
||||||
|
async def on_lookup(self, context):
|
||||||
|
... # check external storage
|
||||||
|
|
||||||
|
async def on_store(self, context, value):
|
||||||
|
... # store to external storage
|
||||||
|
|
||||||
|
Caching.register_provider(MyCacheProvider())
|
||||||
|
"""
|
||||||
|
from ._caching import CacheProvider, CacheContext, CacheValue
|
||||||
|
|
||||||
|
async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||||
|
"""Register an external cache provider. Providers are called in registration order."""
|
||||||
|
from comfy_execution.cache_provider import register_cache_provider
|
||||||
|
register_cache_provider(provider)
|
||||||
|
|
||||||
|
async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
|
||||||
|
"""Unregister a previously registered cache provider."""
|
||||||
|
from comfy_execution.cache_provider import unregister_cache_provider
|
||||||
|
unregister_cache_provider(provider)
|
||||||
|
|
||||||
class ComfyExtension(ABC):
|
class ComfyExtension(ABC):
|
||||||
async def on_load(self) -> None:
|
async def on_load(self) -> None:
|
||||||
"""
|
"""
|
||||||
@ -116,6 +147,9 @@ class Types:
|
|||||||
VOXEL = VOXEL
|
VOXEL = VOXEL
|
||||||
File3D = File3D
|
File3D = File3D
|
||||||
|
|
||||||
|
|
||||||
|
Caching = ComfyAPI_latest.Caching
|
||||||
|
|
||||||
ComfyAPI = ComfyAPI_latest
|
ComfyAPI = ComfyAPI_latest
|
||||||
|
|
||||||
# Create a synchronous version of the API
|
# Create a synchronous version of the API
|
||||||
@ -135,6 +169,7 @@ __all__ = [
|
|||||||
"Input",
|
"Input",
|
||||||
"InputImpl",
|
"InputImpl",
|
||||||
"Types",
|
"Types",
|
||||||
|
"Caching",
|
||||||
"ComfyExtension",
|
"ComfyExtension",
|
||||||
"io",
|
"io",
|
||||||
"IO",
|
"IO",
|
||||||
|
|||||||
42
comfy_api/latest/_caching.py
Normal file
42
comfy_api/latest/_caching.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import Optional
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheContext:
|
||||||
|
node_id: str
|
||||||
|
class_type: str
|
||||||
|
cache_key_hash: str # SHA256 hex digest
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class CacheValue:
|
||||||
|
outputs: list
|
||||||
|
ui: dict = None
|
||||||
|
|
||||||
|
|
||||||
|
class CacheProvider(ABC):
|
||||||
|
"""Abstract base class for external cache providers.
|
||||||
|
Exceptions from provider methods are caught by the caller and never break execution.
|
||||||
|
"""
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
|
"""Called on local cache miss. Return CacheValue if found, None otherwise."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
|
"""Called after local store. Dispatched via asyncio.create_task."""
|
||||||
|
pass
|
||||||
|
|
||||||
|
def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
|
||||||
|
"""Return False to skip external caching for this node. Default: True."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def on_prompt_start(self, prompt_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def on_prompt_end(self, prompt_id: str) -> None:
|
||||||
|
pass
|
||||||
@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
|
|||||||
inlineData: GeminiInlineData | None = Field(None)
|
inlineData: GeminiInlineData | None = Field(None)
|
||||||
fileData: GeminiFileData | None = Field(None)
|
fileData: GeminiFileData | None = Field(None)
|
||||||
text: str | None = Field(None)
|
text: str | None = Field(None)
|
||||||
|
thought: bool | None = Field(None)
|
||||||
|
|
||||||
|
|
||||||
class GeminiTextPart(BaseModel):
|
class GeminiTextPart(BaseModel):
|
||||||
|
|||||||
43
comfy_api_nodes/apis/quiver.py
Normal file
43
comfy_api_nodes/apis/quiver.py
Normal file
@ -0,0 +1,43 @@
|
|||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverImageObject(BaseModel):
|
||||||
|
url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverTextToSVGRequest(BaseModel):
|
||||||
|
model: str = Field(default="arrow-preview")
|
||||||
|
prompt: str = Field(...)
|
||||||
|
instructions: str | None = Field(default=None)
|
||||||
|
references: list[QuiverImageObject] | None = Field(default=None, max_length=4)
|
||||||
|
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||||
|
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||||
|
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverImageToSVGRequest(BaseModel):
|
||||||
|
model: str = Field(default="arrow-preview")
|
||||||
|
image: QuiverImageObject = Field(...)
|
||||||
|
auto_crop: bool | None = Field(default=None)
|
||||||
|
target_size: int | None = Field(default=None, ge=128, le=4096)
|
||||||
|
temperature: float | None = Field(default=None, ge=0, le=2)
|
||||||
|
top_p: float | None = Field(default=None, ge=0, le=1)
|
||||||
|
presence_penalty: float | None = Field(default=None, ge=-2, le=2)
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverSVGResponseItem(BaseModel):
|
||||||
|
svg: str = Field(...)
|
||||||
|
mime_type: str | None = Field(default="image/svg+xml")
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverSVGUsage(BaseModel):
|
||||||
|
total_tokens: int | None = Field(default=None)
|
||||||
|
input_tokens: int | None = Field(default=None)
|
||||||
|
output_tokens: int | None = Field(default=None)
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverSVGResponse(BaseModel):
|
||||||
|
id: str | None = Field(default=None)
|
||||||
|
created: int | None = Field(default=None)
|
||||||
|
data: list[QuiverSVGResponseItem] = Field(...)
|
||||||
|
usage: QuiverSVGUsage | None = Field(default=None)
|
||||||
@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
|
|||||||
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
|
||||||
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
|
||||||
|
|
||||||
|
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
|
||||||
if response.error:
|
if response.error:
|
||||||
@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
|
|||||||
price_badge=IO.PriceBadge(
|
price_badge=IO.PriceBadge(
|
||||||
expr="""{"type":"usd","usd":0.03}""",
|
expr="""{"type":"usd","usd":0.03}""",
|
||||||
),
|
),
|
||||||
|
is_deprecated=True,
|
||||||
)
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
|
|||||||
]
|
]
|
||||||
return await process_video_task(
|
return await process_video_task(
|
||||||
cls,
|
cls,
|
||||||
payload=Image2VideoTaskCreationRequest(model=model, content=x),
|
payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
|
||||||
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
|
||||||
)
|
)
|
||||||
|
|
||||||
@ -952,6 +957,12 @@ async def process_video_task(
|
|||||||
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
|
||||||
estimated_duration: int | None,
|
estimated_duration: int | None,
|
||||||
) -> IO.NodeOutput:
|
) -> IO.NodeOutput:
|
||||||
|
if payload.model in DEPRECATED_MODELS:
|
||||||
|
logger.warning(
|
||||||
|
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
|
||||||
|
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
|
||||||
|
payload.model,
|
||||||
|
)
|
||||||
initial_response = await sync_op(
|
initial_response = await sync_op(
|
||||||
cls,
|
cls,
|
||||||
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),
|
||||||
|
|||||||
@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
|
|||||||
$m := widgets.model;
|
$m := widgets.model;
|
||||||
$r := widgets.resolution;
|
$r := widgets.resolution;
|
||||||
$isFlash := $contains($m, "nano banana 2");
|
$isFlash := $contains($m, "nano banana 2");
|
||||||
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123};
|
$flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
|
||||||
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
|
||||||
$prices := $isFlash ? $flashPrices : $proPrices;
|
$prices := $isFlash ? $flashPrices : $proPrices;
|
||||||
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
|
||||||
@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
|
|||||||
return "\n".join([part.text for part in parts])
|
return "\n".join([part.text for part in parts])
|
||||||
|
|
||||||
|
|
||||||
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image:
|
async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
|
||||||
image_tensors: list[Input.Image] = []
|
image_tensors: list[Input.Image] = []
|
||||||
parts = get_parts_by_type(response, "image/*")
|
parts = get_parts_by_type(response, "image/*")
|
||||||
for part in parts:
|
for part in parts:
|
||||||
|
if (part.thought is True) != thought:
|
||||||
|
continue
|
||||||
if part.inlineData:
|
if part.inlineData:
|
||||||
image_data = base64.b64decode(part.inlineData.data)
|
image_data = base64.b64decode(part.inlineData.data)
|
||||||
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
returned_image = bytesio_to_image_tensor(BytesIO(image_data))
|
||||||
@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
|||||||
outputs=[
|
outputs=[
|
||||||
IO.Image.Output(),
|
IO.Image.Output(),
|
||||||
IO.String.Output(),
|
IO.String.Output(),
|
||||||
|
IO.Image.Output(
|
||||||
|
display_name="thought_image",
|
||||||
|
tooltip="First image from the model's thinking process. "
|
||||||
|
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
|
||||||
|
),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
|
|||||||
response_model=GeminiGenerateContentResponse,
|
response_model=GeminiGenerateContentResponse,
|
||||||
price_extractor=calculate_tokens_price,
|
price_extractor=calculate_tokens_price,
|
||||||
)
|
)
|
||||||
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response))
|
return IO.NodeOutput(
|
||||||
|
await get_image_from_response(response),
|
||||||
|
get_text_from_response(response),
|
||||||
|
await get_image_from_response(response, thought=True),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GeminiExtension(ComfyExtension):
|
class GeminiExtension(ComfyExtension):
|
||||||
|
|||||||
@ -1,3 +1,7 @@
|
|||||||
|
import zipfile
|
||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
import torch
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||||
@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
|
|||||||
)
|
)
|
||||||
from comfy_api_nodes.util import (
|
from comfy_api_nodes.util import (
|
||||||
ApiEndpoint,
|
ApiEndpoint,
|
||||||
|
bytesio_to_image_tensor,
|
||||||
|
download_url_to_bytesio,
|
||||||
download_url_to_file_3d,
|
download_url_to_file_3d,
|
||||||
|
download_url_to_image_tensor,
|
||||||
downscale_image_tensor_by_max_side,
|
downscale_image_tensor_by_max_side,
|
||||||
poll_op,
|
poll_op,
|
||||||
sync_op,
|
sync_op,
|
||||||
@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ObjZipResult:
|
||||||
|
__slots__ = ("obj", "texture", "metallic", "normal", "roughness")
|
||||||
|
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
obj: Types.File3D,
|
||||||
|
texture: Input.Image | None = None,
|
||||||
|
metallic: Input.Image | None = None,
|
||||||
|
normal: Input.Image | None = None,
|
||||||
|
roughness: Input.Image | None = None,
|
||||||
|
):
|
||||||
|
self.obj = obj
|
||||||
|
self.texture = texture
|
||||||
|
self.metallic = metallic
|
||||||
|
self.normal = normal
|
||||||
|
self.roughness = roughness
|
||||||
|
|
||||||
|
|
||||||
|
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
|
||||||
|
"""The Tencent API returns OBJ results as ZIP archives containing the .obj mesh, and texture images.
|
||||||
|
|
||||||
|
When PBR is enabled, the ZIP may contain additional metallic, normal, and roughness maps
|
||||||
|
identified by their filename suffixes.
|
||||||
|
"""
|
||||||
|
data = BytesIO()
|
||||||
|
await download_url_to_bytesio(url, data)
|
||||||
|
data.seek(0)
|
||||||
|
if not zipfile.is_zipfile(data):
|
||||||
|
data.seek(0)
|
||||||
|
return ObjZipResult(obj=Types.File3D(source=data, file_format="obj"))
|
||||||
|
data.seek(0)
|
||||||
|
obj_bytes = None
|
||||||
|
textures: dict[str, Input.Image] = {}
|
||||||
|
with zipfile.ZipFile(data) as zf:
|
||||||
|
for name in zf.namelist():
|
||||||
|
lower = name.lower()
|
||||||
|
if lower.endswith(".obj"):
|
||||||
|
obj_bytes = zf.read(name)
|
||||||
|
elif any(lower.endswith(ext) for ext in (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")):
|
||||||
|
stem = lower.rsplit(".", 1)[0]
|
||||||
|
tensor = bytesio_to_image_tensor(BytesIO(zf.read(name)), mode="RGB")
|
||||||
|
matched_key = "texture"
|
||||||
|
for suffix, key in {
|
||||||
|
"_metallic": "metallic",
|
||||||
|
"_normal": "normal",
|
||||||
|
"_roughness": "roughness",
|
||||||
|
}.items():
|
||||||
|
if stem.endswith(suffix):
|
||||||
|
matched_key = key
|
||||||
|
break
|
||||||
|
textures[matched_key] = tensor
|
||||||
|
if obj_bytes is None:
|
||||||
|
raise ValueError("ZIP archive does not contain an OBJ file.")
|
||||||
|
return ObjZipResult(
|
||||||
|
obj=Types.File3D(source=BytesIO(obj_bytes), file_format="obj"),
|
||||||
|
texture=textures.get("texture"),
|
||||||
|
metallic=textures.get("metallic"),
|
||||||
|
normal=textures.get("normal"),
|
||||||
|
roughness=textures.get("roughness"),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def get_file_from_response(
|
def get_file_from_response(
|
||||||
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
|
||||||
) -> ResultFile3D | None:
|
) -> ResultFile3D | None:
|
||||||
@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
|
|||||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||||
IO.File3DGLB.Output(display_name="GLB"),
|
IO.File3DGLB.Output(display_name="GLB"),
|
||||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||||
|
IO.Image.Output(display_name="texture_image"),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
|
|||||||
response_model=To3DProTaskResultResponse,
|
response_model=To3DProTaskResultResponse,
|
||||||
status_extractor=lambda r: r.Status,
|
status_extractor=lambda r: r.Status,
|
||||||
)
|
)
|
||||||
|
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
f"{task_id}.glb",
|
f"{task_id}.glb",
|
||||||
await download_url_to_file_3d(
|
await download_url_to_file_3d(
|
||||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||||
),
|
),
|
||||||
await download_url_to_file_3d(
|
obj_result.obj,
|
||||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
obj_result.texture,
|
||||||
),
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
|
|||||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||||
IO.File3DGLB.Output(display_name="GLB"),
|
IO.File3DGLB.Output(display_name="GLB"),
|
||||||
IO.File3DOBJ.Output(display_name="OBJ"),
|
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||||
|
IO.Image.Output(display_name="texture_image"),
|
||||||
|
IO.Image.Output(display_name="optional_metallic"),
|
||||||
|
IO.Image.Output(display_name="optional_normal"),
|
||||||
|
IO.Image.Output(display_name="optional_roughness"),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
|
|||||||
response_model=To3DProTaskResultResponse,
|
response_model=To3DProTaskResultResponse,
|
||||||
status_extractor=lambda r: r.Status,
|
status_extractor=lambda r: r.Status,
|
||||||
)
|
)
|
||||||
|
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
f"{task_id}.glb",
|
f"{task_id}.glb",
|
||||||
await download_url_to_file_3d(
|
await download_url_to_file_3d(
|
||||||
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
|
||||||
),
|
),
|
||||||
await download_url_to_file_3d(
|
obj_result.obj,
|
||||||
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id
|
obj_result.texture,
|
||||||
),
|
obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
|
||||||
|
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
|
||||||
|
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
|||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
IO.File3DGLB.Output(display_name="GLB"),
|
IO.File3DGLB.Output(display_name="GLB"),
|
||||||
IO.File3DFBX.Output(display_name="FBX"),
|
IO.File3DOBJ.Output(display_name="OBJ"),
|
||||||
|
IO.Image.Output(display_name="texture_image"),
|
||||||
],
|
],
|
||||||
hidden=[
|
hidden=[
|
||||||
IO.Hidden.auth_token_comfy_org,
|
IO.Hidden.auth_token_comfy_org,
|
||||||
@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
|
|||||||
)
|
)
|
||||||
return IO.NodeOutput(
|
return IO.NodeOutput(
|
||||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
|
||||||
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"),
|
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
|
||||||
|
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
|
|||||||
TencentTextToModelNode,
|
TencentTextToModelNode,
|
||||||
TencentImageToModelNode,
|
TencentImageToModelNode,
|
||||||
TencentModelTo3DUVNode,
|
TencentModelTo3DUVNode,
|
||||||
# Tencent3DTextureEditNode,
|
Tencent3DTextureEditNode,
|
||||||
Tencent3DPartNode,
|
Tencent3DPartNode,
|
||||||
TencentSmartTopologyNode,
|
TencentSmartTopologyNode,
|
||||||
]
|
]
|
||||||
|
|||||||
@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
|
|||||||
node_id="KlingOmniProEditVideoNode",
|
node_id="KlingOmniProEditVideoNode",
|
||||||
display_name="Kling 3.0 Omni Edit Video",
|
display_name="Kling 3.0 Omni Edit Video",
|
||||||
category="api node/video/Kling",
|
category="api node/video/Kling",
|
||||||
|
essentials_category="Video Generation",
|
||||||
description="Edit an existing video with the latest model from Kling.",
|
description="Edit an existing video with the latest model from Kling.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),
|
||||||
|
|||||||
291
comfy_api_nodes/nodes_quiver.py
Normal file
291
comfy_api_nodes/nodes_quiver.py
Normal file
@ -0,0 +1,291 @@
|
|||||||
|
from io import BytesIO
|
||||||
|
|
||||||
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from comfy_api.latest import IO, ComfyExtension
|
||||||
|
from comfy_api_nodes.apis.quiver import (
|
||||||
|
QuiverImageObject,
|
||||||
|
QuiverImageToSVGRequest,
|
||||||
|
QuiverSVGResponse,
|
||||||
|
QuiverTextToSVGRequest,
|
||||||
|
)
|
||||||
|
from comfy_api_nodes.util import (
|
||||||
|
ApiEndpoint,
|
||||||
|
sync_op,
|
||||||
|
upload_image_to_comfyapi,
|
||||||
|
validate_string,
|
||||||
|
)
|
||||||
|
from comfy_extras.nodes_images import SVG
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverTextToSVGNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="QuiverTextToSVGNode",
|
||||||
|
display_name="Quiver Text to SVG",
|
||||||
|
category="api node/image/Quiver",
|
||||||
|
description="Generate an SVG from a text prompt using Quiver AI.",
|
||||||
|
inputs=[
|
||||||
|
IO.String.Input(
|
||||||
|
"prompt",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Text description of the desired SVG output.",
|
||||||
|
),
|
||||||
|
IO.String.Input(
|
||||||
|
"instructions",
|
||||||
|
multiline=True,
|
||||||
|
default="",
|
||||||
|
tooltip="Additional style or formatting guidance.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.Autogrow.Input(
|
||||||
|
"reference_images",
|
||||||
|
template=IO.Autogrow.TemplatePrefix(
|
||||||
|
IO.Image.Input("image"),
|
||||||
|
prefix="ref_",
|
||||||
|
min=0,
|
||||||
|
max=4,
|
||||||
|
),
|
||||||
|
tooltip="Up to 4 reference images to guide the generation.",
|
||||||
|
optional=True,
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"arrow-preview",
|
||||||
|
[
|
||||||
|
IO.Float.Input(
|
||||||
|
"temperature",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Randomness control. Higher values increase randomness.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"top_p",
|
||||||
|
default=1.0,
|
||||||
|
min=0.05,
|
||||||
|
max=1.0,
|
||||||
|
step=0.05,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Nucleus sampling parameter.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"presence_penalty",
|
||||||
|
default=0.0,
|
||||||
|
min=-2.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Token presence penalty.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="Model to use for SVG generation.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.SVG.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
expr="""{"type":"usd","usd":0.429}""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
prompt: str,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
instructions: str = None,
|
||||||
|
reference_images: IO.Autogrow.Type = None,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
validate_string(prompt, strip_whitespace=False, min_length=1)
|
||||||
|
|
||||||
|
references = None
|
||||||
|
if reference_images:
|
||||||
|
references = []
|
||||||
|
for key in reference_images:
|
||||||
|
url = await upload_image_to_comfyapi(cls, reference_images[key])
|
||||||
|
references.append(QuiverImageObject(url=url))
|
||||||
|
if len(references) > 4:
|
||||||
|
raise ValueError("Maximum 4 reference images are allowed.")
|
||||||
|
|
||||||
|
instructions_val = instructions.strip() if instructions else None
|
||||||
|
if instructions_val == "":
|
||||||
|
instructions_val = None
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
|
||||||
|
response_model=QuiverSVGResponse,
|
||||||
|
data=QuiverTextToSVGRequest(
|
||||||
|
model=model["model"],
|
||||||
|
prompt=prompt,
|
||||||
|
instructions=instructions_val,
|
||||||
|
references=references,
|
||||||
|
temperature=model.get("temperature"),
|
||||||
|
top_p=model.get("top_p"),
|
||||||
|
presence_penalty=model.get("presence_penalty"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||||
|
return IO.NodeOutput(SVG(svg_data))
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverImageToSVGNode(IO.ComfyNode):
|
||||||
|
@classmethod
|
||||||
|
def define_schema(cls):
|
||||||
|
return IO.Schema(
|
||||||
|
node_id="QuiverImageToSVGNode",
|
||||||
|
display_name="Quiver Image to SVG",
|
||||||
|
category="api node/image/Quiver",
|
||||||
|
description="Vectorize a raster image into SVG using Quiver AI.",
|
||||||
|
inputs=[
|
||||||
|
IO.Image.Input(
|
||||||
|
"image",
|
||||||
|
tooltip="Input image to vectorize.",
|
||||||
|
),
|
||||||
|
IO.Boolean.Input(
|
||||||
|
"auto_crop",
|
||||||
|
default=False,
|
||||||
|
tooltip="Automatically crop to the dominant subject.",
|
||||||
|
),
|
||||||
|
IO.DynamicCombo.Input(
|
||||||
|
"model",
|
||||||
|
options=[
|
||||||
|
IO.DynamicCombo.Option(
|
||||||
|
"arrow-preview",
|
||||||
|
[
|
||||||
|
IO.Int.Input(
|
||||||
|
"target_size",
|
||||||
|
default=1024,
|
||||||
|
min=128,
|
||||||
|
max=4096,
|
||||||
|
tooltip="Square resize target in pixels.",
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"temperature",
|
||||||
|
default=1.0,
|
||||||
|
min=0.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Randomness control. Higher values increase randomness.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"top_p",
|
||||||
|
default=1.0,
|
||||||
|
min=0.05,
|
||||||
|
max=1.0,
|
||||||
|
step=0.05,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Nucleus sampling parameter.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
IO.Float.Input(
|
||||||
|
"presence_penalty",
|
||||||
|
default=0.0,
|
||||||
|
min=-2.0,
|
||||||
|
max=2.0,
|
||||||
|
step=0.1,
|
||||||
|
display_mode=IO.NumberDisplay.slider,
|
||||||
|
tooltip="Token presence penalty.",
|
||||||
|
advanced=True,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
),
|
||||||
|
],
|
||||||
|
tooltip="Model to use for SVG vectorization.",
|
||||||
|
),
|
||||||
|
IO.Int.Input(
|
||||||
|
"seed",
|
||||||
|
default=0,
|
||||||
|
min=0,
|
||||||
|
max=2147483647,
|
||||||
|
control_after_generate=True,
|
||||||
|
tooltip="Seed to determine if node should re-run; "
|
||||||
|
"actual results are nondeterministic regardless of seed.",
|
||||||
|
),
|
||||||
|
],
|
||||||
|
outputs=[
|
||||||
|
IO.SVG.Output(),
|
||||||
|
],
|
||||||
|
hidden=[
|
||||||
|
IO.Hidden.auth_token_comfy_org,
|
||||||
|
IO.Hidden.api_key_comfy_org,
|
||||||
|
IO.Hidden.unique_id,
|
||||||
|
],
|
||||||
|
is_api_node=True,
|
||||||
|
price_badge=IO.PriceBadge(
|
||||||
|
expr="""{"type":"usd","usd":0.429}""",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def execute(
|
||||||
|
cls,
|
||||||
|
image,
|
||||||
|
auto_crop: bool,
|
||||||
|
model: dict,
|
||||||
|
seed: int,
|
||||||
|
) -> IO.NodeOutput:
|
||||||
|
image_url = await upload_image_to_comfyapi(cls, image)
|
||||||
|
|
||||||
|
response = await sync_op(
|
||||||
|
cls,
|
||||||
|
ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
|
||||||
|
response_model=QuiverSVGResponse,
|
||||||
|
data=QuiverImageToSVGRequest(
|
||||||
|
model=model["model"],
|
||||||
|
image=QuiverImageObject(url=image_url),
|
||||||
|
auto_crop=auto_crop if auto_crop else None,
|
||||||
|
target_size=model.get("target_size"),
|
||||||
|
temperature=model.get("temperature"),
|
||||||
|
top_p=model.get("top_p"),
|
||||||
|
presence_penalty=model.get("presence_penalty"),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
|
||||||
|
return IO.NodeOutput(SVG(svg_data))
|
||||||
|
|
||||||
|
|
||||||
|
class QuiverExtension(ComfyExtension):
|
||||||
|
@override
|
||||||
|
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||||
|
return [
|
||||||
|
QuiverTextToSVGNode,
|
||||||
|
QuiverImageToSVGNode,
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
|
async def comfy_entrypoint() -> QuiverExtension:
|
||||||
|
return QuiverExtension()
|
||||||
@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
|
|||||||
node_id="RecraftVectorizeImageNode",
|
node_id="RecraftVectorizeImageNode",
|
||||||
display_name="Recraft Vectorize Image",
|
display_name="Recraft Vectorize Image",
|
||||||
category="api node/image/Recraft",
|
category="api node/image/Recraft",
|
||||||
|
essentials_category="Image Tools",
|
||||||
description="Generates SVG synchronously from an input image.",
|
description="Generates SVG synchronously from an input image.",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
|
|||||||
138
comfy_execution/cache_provider.py
Normal file
138
comfy_execution/cache_provider.py
Normal file
@ -0,0 +1,138 @@
|
|||||||
|
from typing import Any, Optional, Tuple, List
|
||||||
|
import hashlib
|
||||||
|
import json
|
||||||
|
import logging
|
||||||
|
import threading
|
||||||
|
|
||||||
|
# Public types — source of truth is comfy_api.latest._caching
|
||||||
|
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
|
||||||
|
|
||||||
|
_logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
_providers: List[CacheProvider] = []
|
||||||
|
_providers_lock = threading.Lock()
|
||||||
|
_providers_snapshot: Tuple[CacheProvider, ...] = ()
|
||||||
|
|
||||||
|
|
||||||
|
def register_cache_provider(provider: CacheProvider) -> None:
|
||||||
|
"""Register an external cache provider. Providers are called in registration order."""
|
||||||
|
global _providers_snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
if provider in _providers:
|
||||||
|
_logger.warning(f"Provider {provider.__class__.__name__} already registered")
|
||||||
|
return
|
||||||
|
_providers.append(provider)
|
||||||
|
_providers_snapshot = tuple(_providers)
|
||||||
|
_logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def unregister_cache_provider(provider: CacheProvider) -> None:
|
||||||
|
global _providers_snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
try:
|
||||||
|
_providers.remove(provider)
|
||||||
|
_providers_snapshot = tuple(_providers)
|
||||||
|
_logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
|
||||||
|
except ValueError:
|
||||||
|
_logger.warning(f"Provider {provider.__class__.__name__} was not registered")
|
||||||
|
|
||||||
|
|
||||||
|
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
|
||||||
|
return _providers_snapshot
|
||||||
|
|
||||||
|
|
||||||
|
def _has_cache_providers() -> bool:
|
||||||
|
return bool(_providers_snapshot)
|
||||||
|
|
||||||
|
|
||||||
|
def _clear_cache_providers() -> None:
|
||||||
|
global _providers_snapshot
|
||||||
|
with _providers_lock:
|
||||||
|
_providers.clear()
|
||||||
|
_providers_snapshot = ()
|
||||||
|
|
||||||
|
|
||||||
|
def _canonicalize(obj: Any) -> Any:
|
||||||
|
# Convert to canonical JSON-serializable form with deterministic ordering.
|
||||||
|
# Frozensets have non-deterministic iteration order between Python sessions.
|
||||||
|
# Raises ValueError for non-cacheable types (Unhashable, unknown) so that
|
||||||
|
# _serialize_cache_key returns None and external caching is skipped.
|
||||||
|
if isinstance(obj, frozenset):
|
||||||
|
return ("__frozenset__", sorted(
|
||||||
|
[_canonicalize(item) for item in obj],
|
||||||
|
key=lambda x: json.dumps(x, sort_keys=True)
|
||||||
|
))
|
||||||
|
elif isinstance(obj, set):
|
||||||
|
return ("__set__", sorted(
|
||||||
|
[_canonicalize(item) for item in obj],
|
||||||
|
key=lambda x: json.dumps(x, sort_keys=True)
|
||||||
|
))
|
||||||
|
elif isinstance(obj, tuple):
|
||||||
|
return ("__tuple__", [_canonicalize(item) for item in obj])
|
||||||
|
elif isinstance(obj, list):
|
||||||
|
return [_canonicalize(item) for item in obj]
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
return {"__dict__": sorted(
|
||||||
|
[[_canonicalize(k), _canonicalize(v)] for k, v in obj.items()],
|
||||||
|
key=lambda x: json.dumps(x, sort_keys=True)
|
||||||
|
)}
|
||||||
|
elif isinstance(obj, (int, float, str, bool, type(None))):
|
||||||
|
return (type(obj).__name__, obj)
|
||||||
|
elif isinstance(obj, bytes):
|
||||||
|
return ("__bytes__", obj.hex())
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
|
||||||
|
|
||||||
|
|
||||||
|
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
|
||||||
|
# Returns deterministic SHA256 hex digest, or None on failure.
|
||||||
|
# Uses JSON (not pickle) because pickle is non-deterministic across sessions.
|
||||||
|
try:
|
||||||
|
canonical = _canonicalize(cache_key)
|
||||||
|
json_str = json.dumps(canonical, sort_keys=True, separators=(',', ':'))
|
||||||
|
return hashlib.sha256(json_str.encode('utf-8')).hexdigest()
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f"Failed to serialize cache key: {e}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
|
||||||
|
def _contains_self_unequal(obj: Any) -> bool:
|
||||||
|
# Local cache matches by ==. Values where not (x == x) (NaN, etc.) will
|
||||||
|
# never hit locally, but serialized form would match externally. Skip these.
|
||||||
|
try:
|
||||||
|
if not (obj == obj):
|
||||||
|
return True
|
||||||
|
except Exception:
|
||||||
|
return True
|
||||||
|
if isinstance(obj, (frozenset, tuple, list, set)):
|
||||||
|
return any(_contains_self_unequal(item) for item in obj)
|
||||||
|
if isinstance(obj, dict):
|
||||||
|
return any(_contains_self_unequal(k) or _contains_self_unequal(v) for k, v in obj.items())
|
||||||
|
if hasattr(obj, 'value'):
|
||||||
|
return _contains_self_unequal(obj.value)
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
|
def _estimate_value_size(value: CacheValue) -> int:
|
||||||
|
try:
|
||||||
|
import torch
|
||||||
|
except ImportError:
|
||||||
|
return 0
|
||||||
|
|
||||||
|
total = 0
|
||||||
|
|
||||||
|
def estimate(obj):
|
||||||
|
nonlocal total
|
||||||
|
if isinstance(obj, torch.Tensor):
|
||||||
|
total += obj.numel() * obj.element_size()
|
||||||
|
elif isinstance(obj, dict):
|
||||||
|
for v in obj.values():
|
||||||
|
estimate(v)
|
||||||
|
elif isinstance(obj, (list, tuple)):
|
||||||
|
for item in obj:
|
||||||
|
estimate(item)
|
||||||
|
|
||||||
|
for output in value.outputs:
|
||||||
|
estimate(output)
|
||||||
|
return total
|
||||||
@ -1,3 +1,4 @@
|
|||||||
|
import asyncio
|
||||||
import bisect
|
import bisect
|
||||||
import gc
|
import gc
|
||||||
import itertools
|
import itertools
|
||||||
@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
|
|||||||
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
|
||||||
|
|
||||||
class BasicCache:
|
class BasicCache:
|
||||||
def __init__(self, key_class):
|
def __init__(self, key_class, enable_providers=False):
|
||||||
self.key_class = key_class
|
self.key_class = key_class
|
||||||
self.initialized = False
|
self.initialized = False
|
||||||
|
self.enable_providers = enable_providers
|
||||||
self.dynprompt: DynamicPrompt
|
self.dynprompt: DynamicPrompt
|
||||||
self.cache_key_set: CacheKeySet
|
self.cache_key_set: CacheKeySet
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.subcaches = {}
|
self.subcaches = {}
|
||||||
|
self._pending_store_tasks: set = set()
|
||||||
|
|
||||||
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
|
||||||
self.dynprompt = dynprompt
|
self.dynprompt = dynprompt
|
||||||
@ -196,18 +199,138 @@ class BasicCache:
|
|||||||
def poll(self, **kwargs):
|
def poll(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def _set_immediate(self, node_id, value):
|
def get_local(self, node_id):
|
||||||
assert self.initialized
|
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
|
||||||
self.cache[cache_key] = value
|
|
||||||
|
|
||||||
def _get_immediate(self, node_id):
|
|
||||||
if not self.initialized:
|
if not self.initialized:
|
||||||
return None
|
return None
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
if cache_key in self.cache:
|
if cache_key in self.cache:
|
||||||
return self.cache[cache_key]
|
return self.cache[cache_key]
|
||||||
else:
|
return None
|
||||||
|
|
||||||
|
def set_local(self, node_id, value):
|
||||||
|
assert self.initialized
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
|
async def _set_immediate(self, node_id, value):
|
||||||
|
assert self.initialized
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
self.cache[cache_key] = value
|
||||||
|
|
||||||
|
await self._notify_providers_store(node_id, cache_key, value)
|
||||||
|
|
||||||
|
async def _get_immediate(self, node_id):
|
||||||
|
if not self.initialized:
|
||||||
|
return None
|
||||||
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
|
|
||||||
|
if cache_key in self.cache:
|
||||||
|
return self.cache[cache_key]
|
||||||
|
|
||||||
|
external_result = await self._check_providers_lookup(node_id, cache_key)
|
||||||
|
if external_result is not None:
|
||||||
|
self.cache[cache_key] = external_result
|
||||||
|
return external_result
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _notify_providers_store(self, node_id, cache_key, value):
|
||||||
|
from comfy_execution.cache_provider import (
|
||||||
|
_has_cache_providers, _get_cache_providers,
|
||||||
|
CacheValue, _contains_self_unequal, _logger
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.enable_providers:
|
||||||
|
return
|
||||||
|
if not _has_cache_providers():
|
||||||
|
return
|
||||||
|
if not self._is_external_cacheable_value(value):
|
||||||
|
return
|
||||||
|
if _contains_self_unequal(cache_key):
|
||||||
|
return
|
||||||
|
|
||||||
|
context = self._build_context(node_id, cache_key)
|
||||||
|
if context is None:
|
||||||
|
return
|
||||||
|
cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
|
||||||
|
|
||||||
|
for provider in _get_cache_providers():
|
||||||
|
try:
|
||||||
|
if provider.should_cache(context, cache_value):
|
||||||
|
task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
|
||||||
|
self._pending_store_tasks.add(task)
|
||||||
|
task.add_done_callback(self._pending_store_tasks.discard)
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def _safe_provider_store(provider, context, cache_value):
|
||||||
|
from comfy_execution.cache_provider import _logger
|
||||||
|
try:
|
||||||
|
await provider.on_store(context, cache_value)
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {e}")
|
||||||
|
|
||||||
|
async def _check_providers_lookup(self, node_id, cache_key):
|
||||||
|
from comfy_execution.cache_provider import (
|
||||||
|
_has_cache_providers, _get_cache_providers,
|
||||||
|
CacheValue, _contains_self_unequal, _logger
|
||||||
|
)
|
||||||
|
|
||||||
|
if not self.enable_providers:
|
||||||
|
return None
|
||||||
|
if not _has_cache_providers():
|
||||||
|
return None
|
||||||
|
if _contains_self_unequal(cache_key):
|
||||||
|
return None
|
||||||
|
|
||||||
|
context = self._build_context(node_id, cache_key)
|
||||||
|
if context is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
for provider in _get_cache_providers():
|
||||||
|
try:
|
||||||
|
if not provider.should_cache(context):
|
||||||
|
continue
|
||||||
|
result = await provider.on_lookup(context)
|
||||||
|
if result is not None:
|
||||||
|
if not isinstance(result, CacheValue):
|
||||||
|
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
|
||||||
|
continue
|
||||||
|
if not isinstance(result.outputs, (list, tuple)):
|
||||||
|
_logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
|
||||||
|
continue
|
||||||
|
from execution import CacheEntry
|
||||||
|
return CacheEntry(ui=result.ui, outputs=list(result.outputs))
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _is_external_cacheable_value(self, value):
|
||||||
|
return hasattr(value, 'outputs') and hasattr(value, 'ui')
|
||||||
|
|
||||||
|
def _get_class_type(self, node_id):
|
||||||
|
if not self.initialized or not self.dynprompt:
|
||||||
|
return ''
|
||||||
|
try:
|
||||||
|
return self.dynprompt.get_node(node_id).get('class_type', '')
|
||||||
|
except Exception:
|
||||||
|
return ''
|
||||||
|
|
||||||
|
def _build_context(self, node_id, cache_key):
|
||||||
|
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
|
||||||
|
try:
|
||||||
|
cache_key_hash = _serialize_cache_key(cache_key)
|
||||||
|
if cache_key_hash is None:
|
||||||
|
return None
|
||||||
|
return CacheContext(
|
||||||
|
node_id=node_id,
|
||||||
|
class_type=self._get_class_type(node_id),
|
||||||
|
cache_key_hash=cache_key_hash,
|
||||||
|
)
|
||||||
|
except Exception as e:
|
||||||
|
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
|
||||||
return None
|
return None
|
||||||
|
|
||||||
async def _ensure_subcache(self, node_id, children_ids):
|
async def _ensure_subcache(self, node_id, children_ids):
|
||||||
@ -236,8 +359,8 @@ class BasicCache:
|
|||||||
return result
|
return result
|
||||||
|
|
||||||
class HierarchicalCache(BasicCache):
|
class HierarchicalCache(BasicCache):
|
||||||
def __init__(self, key_class):
|
def __init__(self, key_class, enable_providers=False):
|
||||||
super().__init__(key_class)
|
super().__init__(key_class, enable_providers=enable_providers)
|
||||||
|
|
||||||
def _get_cache_for(self, node_id):
|
def _get_cache_for(self, node_id):
|
||||||
assert self.dynprompt is not None
|
assert self.dynprompt is not None
|
||||||
@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
|
|||||||
return None
|
return None
|
||||||
return cache
|
return cache
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
if cache is None:
|
if cache is None:
|
||||||
return None
|
return None
|
||||||
return cache._get_immediate(node_id)
|
return await cache._get_immediate(node_id)
|
||||||
|
|
||||||
def set(self, node_id, value):
|
def get_local(self, node_id):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
if cache is None:
|
||||||
|
return None
|
||||||
|
return BasicCache.get_local(cache, node_id)
|
||||||
|
|
||||||
|
async def set(self, node_id, value):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
assert cache is not None
|
assert cache is not None
|
||||||
cache._set_immediate(node_id, value)
|
await cache._set_immediate(node_id, value)
|
||||||
|
|
||||||
|
def set_local(self, node_id, value):
|
||||||
|
cache = self._get_cache_for(node_id)
|
||||||
|
assert cache is not None
|
||||||
|
BasicCache.set_local(cache, node_id, value)
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
cache = self._get_cache_for(node_id)
|
cache = self._get_cache_for(node_id)
|
||||||
@ -287,18 +421,24 @@ class NullCache:
|
|||||||
def poll(self, **kwargs):
|
def poll(self, **kwargs):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
return None
|
return None
|
||||||
|
|
||||||
def set(self, node_id, value):
|
def get_local(self, node_id):
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def set(self, node_id, value):
|
||||||
|
pass
|
||||||
|
|
||||||
|
def set_local(self, node_id, value):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
class LRUCache(BasicCache):
|
class LRUCache(BasicCache):
|
||||||
def __init__(self, key_class, max_size=100):
|
def __init__(self, key_class, max_size=100, enable_providers=False):
|
||||||
super().__init__(key_class)
|
super().__init__(key_class, enable_providers=enable_providers)
|
||||||
self.max_size = max_size
|
self.max_size = max_size
|
||||||
self.min_generation = 0
|
self.min_generation = 0
|
||||||
self.generation = 0
|
self.generation = 0
|
||||||
@ -322,18 +462,18 @@ class LRUCache(BasicCache):
|
|||||||
del self.children[key]
|
del self.children[key]
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
return self._get_immediate(node_id)
|
return await self._get_immediate(node_id)
|
||||||
|
|
||||||
def _mark_used(self, node_id):
|
def _mark_used(self, node_id):
|
||||||
cache_key = self.cache_key_set.get_data_key(node_id)
|
cache_key = self.cache_key_set.get_data_key(node_id)
|
||||||
if cache_key is not None:
|
if cache_key is not None:
|
||||||
self.used_generation[cache_key] = self.generation
|
self.used_generation[cache_key] = self.generation
|
||||||
|
|
||||||
def set(self, node_id, value):
|
async def set(self, node_id, value):
|
||||||
self._mark_used(node_id)
|
self._mark_used(node_id)
|
||||||
return self._set_immediate(node_id, value)
|
return await self._set_immediate(node_id, value)
|
||||||
|
|
||||||
async def ensure_subcache_for(self, node_id, children_ids):
|
async def ensure_subcache_for(self, node_id, children_ids):
|
||||||
# Just uses subcaches for tracking 'live' nodes
|
# Just uses subcaches for tracking 'live' nodes
|
||||||
@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
|
|||||||
|
|
||||||
class RAMPressureCache(LRUCache):
|
class RAMPressureCache(LRUCache):
|
||||||
|
|
||||||
def __init__(self, key_class):
|
def __init__(self, key_class, enable_providers=False):
|
||||||
super().__init__(key_class, 0)
|
super().__init__(key_class, 0, enable_providers=enable_providers)
|
||||||
self.timestamps = {}
|
self.timestamps = {}
|
||||||
|
|
||||||
def clean_unused(self):
|
def clean_unused(self):
|
||||||
self._clean_subcaches()
|
self._clean_subcaches()
|
||||||
|
|
||||||
def set(self, node_id, value):
|
async def set(self, node_id, value):
|
||||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
super().set(node_id, value)
|
await super().set(node_id, value)
|
||||||
|
|
||||||
def get(self, node_id):
|
async def get(self, node_id):
|
||||||
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
|
||||||
return super().get(node_id)
|
return await super().get(node_id)
|
||||||
|
|
||||||
def poll(self, ram_headroom):
|
def poll(self, ram_headroom):
|
||||||
def _ram_gb():
|
def _ram_gb():
|
||||||
|
|||||||
@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
|
|||||||
self.execution_cache_listeners = {}
|
self.execution_cache_listeners = {}
|
||||||
|
|
||||||
def is_cached(self, node_id):
|
def is_cached(self, node_id):
|
||||||
return self.output_cache.get(node_id) is not None
|
return self.output_cache.get_local(node_id) is not None
|
||||||
|
|
||||||
def cache_link(self, from_node_id, to_node_id):
|
def cache_link(self, from_node_id, to_node_id):
|
||||||
if to_node_id not in self.execution_cache:
|
if to_node_id not in self.execution_cache:
|
||||||
self.execution_cache[to_node_id] = {}
|
self.execution_cache[to_node_id] = {}
|
||||||
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id)
|
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
|
||||||
if from_node_id not in self.execution_cache_listeners:
|
if from_node_id not in self.execution_cache_listeners:
|
||||||
self.execution_cache_listeners[from_node_id] = set()
|
self.execution_cache_listeners[from_node_id] = set()
|
||||||
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
self.execution_cache_listeners[from_node_id].add(to_node_id)
|
||||||
@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
|
|||||||
if value is None:
|
if value is None:
|
||||||
return None
|
return None
|
||||||
#Write back to the main cache on touch.
|
#Write back to the main cache on touch.
|
||||||
self.output_cache.set(from_node_id, value)
|
self.output_cache.set_local(from_node_id, value)
|
||||||
return value
|
return value
|
||||||
|
|
||||||
def cache_update(self, node_id, value):
|
def cache_update(self, node_id, value):
|
||||||
|
|||||||
@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
|
|||||||
node_id="EmptyLatentAudio",
|
node_id="EmptyLatentAudio",
|
||||||
display_name="Empty Latent Audio",
|
display_name="Empty Latent Audio",
|
||||||
category="latent/audio",
|
category="latent/audio",
|
||||||
|
essentials_category="Audio",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
|
||||||
IO.Int.Input(
|
IO.Int.Input(
|
||||||
@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
|||||||
search_aliases=["export mp3"],
|
search_aliases=["export mp3"],
|
||||||
display_name="Save Audio (MP3)",
|
display_name="Save Audio (MP3)",
|
||||||
category="audio",
|
category="audio",
|
||||||
|
essentials_category="Audio",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Audio.Input("audio"),
|
IO.Audio.Input("audio"),
|
||||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||||
|
|||||||
@ -3,6 +3,7 @@ from typing_extensions import override
|
|||||||
|
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
from comfy_api.latest import ComfyExtension, io
|
from comfy_api.latest import ComfyExtension, io
|
||||||
|
import torch
|
||||||
|
|
||||||
|
|
||||||
class Canny(io.ComfyNode):
|
class Canny(io.ComfyNode):
|
||||||
@ -29,8 +30,8 @@ class Canny(io.ComfyNode):
|
|||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
|
||||||
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold)
|
output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
|
||||||
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1)
|
img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
|
||||||
return io.NodeOutput(img_out)
|
return io.NodeOutput(img_out)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
|
|||||||
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
|
||||||
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
|
||||||
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
|
||||||
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
|
||||||
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
|
||||||
],
|
],
|
||||||
outputs=[
|
outputs=[
|
||||||
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
io.Model.Output(tooltip="The model with context windows applied during sampling."),
|
||||||
|
|||||||
@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
|
|||||||
display_name="Image Compare",
|
display_name="Image Compare",
|
||||||
description="Compares two images side by side with a slider.",
|
description="Compares two images side by side with a slider.",
|
||||||
category="image",
|
category="image",
|
||||||
|
essentials_category="Image Tools",
|
||||||
is_experimental=True,
|
is_experimental=True,
|
||||||
is_output_node=True,
|
is_output_node=True,
|
||||||
inputs=[
|
inputs=[
|
||||||
|
|||||||
@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
|
|||||||
search_aliases=["trim"],
|
search_aliases=["trim"],
|
||||||
display_name="Image Crop",
|
display_name="Image Crop",
|
||||||
category="image/transform",
|
category="image/transform",
|
||||||
|
essentials_category="Image Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
IO.Image.Input("image"),
|
IO.Image.Input("image"),
|
||||||
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
IO.BoundingBox.Input("crop_region", component="ImageCrop"),
|
||||||
|
|||||||
@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
|
|||||||
node_id="ImageBlend",
|
node_id="ImageBlend",
|
||||||
display_name="Image Blend",
|
display_name="Image Blend",
|
||||||
category="image/postprocessing",
|
category="image/postprocessing",
|
||||||
|
essentials_category="Image Tools",
|
||||||
inputs=[
|
inputs=[
|
||||||
io.Image.Input("image1"),
|
io.Image.Input("image1"),
|
||||||
io.Image.Input("image2"),
|
io.Image.Input("image2"),
|
||||||
|
|||||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
|||||||
import comfy.sd
|
import comfy.sd
|
||||||
import comfy.utils
|
import comfy.utils
|
||||||
import comfy.model_management
|
import comfy.model_management
|
||||||
|
from comfy.cli_args import args, PerformanceFeature
|
||||||
import comfy_extras.nodes_custom_sampler
|
import comfy_extras.nodes_custom_sampler
|
||||||
import folder_paths
|
import folder_paths
|
||||||
import node_helpers
|
import node_helpers
|
||||||
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
training_dtype=torch.bfloat16,
|
training_dtype=torch.bfloat16,
|
||||||
real_dataset=None,
|
real_dataset=None,
|
||||||
bucket_latents=None,
|
bucket_latents=None,
|
||||||
|
use_grad_scaler=False,
|
||||||
):
|
):
|
||||||
self.loss_fn = loss_fn
|
self.loss_fn = loss_fn
|
||||||
self.optimizer = optimizer
|
self.optimizer = optimizer
|
||||||
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self.bucket_latents: list[torch.Tensor] | None = (
|
self.bucket_latents: list[torch.Tensor] | None = (
|
||||||
bucket_latents # list of (Bi, C, Hi, Wi)
|
bucket_latents # list of (Bi, C, Hi, Wi)
|
||||||
)
|
)
|
||||||
|
# GradScaler for fp16 training
|
||||||
|
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
|
||||||
# Precompute bucket offsets and weights for sampling
|
# Precompute bucket offsets and weights for sampling
|
||||||
if bucket_latents is not None:
|
if bucket_latents is not None:
|
||||||
self._init_bucket_data(bucket_latents)
|
self._init_bucket_data(bucket_latents)
|
||||||
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
batch_sigmas.requires_grad_(True),
|
batch_sigmas.requires_grad_(True),
|
||||||
**batch_extra_args,
|
**batch_extra_args,
|
||||||
)
|
)
|
||||||
loss = self.loss_fn(x0_pred, x0)
|
loss = self.loss_fn(x0_pred.float(), x0.float())
|
||||||
if bwd:
|
if bwd:
|
||||||
bwd_loss = loss / self.grad_acc
|
bwd_loss = loss / self.grad_acc
|
||||||
bwd_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(bwd_loss).backward()
|
||||||
|
else:
|
||||||
|
bwd_loss.backward()
|
||||||
return loss
|
return loss
|
||||||
|
|
||||||
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
def _generate_batch_sigmas(self, model_wrap, batch_size, device):
|
||||||
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
)
|
)
|
||||||
total_loss += loss
|
total_loss += loss
|
||||||
total_loss = total_loss / self.grad_acc / len(indicies)
|
total_loss = total_loss / self.grad_acc / len(indicies)
|
||||||
total_loss.backward()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.scale(total_loss).backward()
|
||||||
|
else:
|
||||||
|
total_loss.backward()
|
||||||
if self.loss_callback:
|
if self.loss_callback:
|
||||||
self.loss_callback(total_loss.item())
|
self.loss_callback(total_loss.item())
|
||||||
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
|
||||||
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
|
|||||||
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
|
||||||
|
|
||||||
if (i + 1) % self.grad_acc == 0:
|
if (i + 1) % self.grad_acc == 0:
|
||||||
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.unscale_(self.optimizer)
|
||||||
for param_groups in self.optimizer.param_groups:
|
for param_groups in self.optimizer.param_groups:
|
||||||
for param in param_groups["params"]:
|
for param in param_groups["params"]:
|
||||||
if param.grad is None:
|
if param.grad is None:
|
||||||
continue
|
continue
|
||||||
param.grad.data = param.grad.data.to(param.data.dtype)
|
param.grad.data = param.grad.data.to(param.data.dtype)
|
||||||
self.optimizer.step()
|
if self.grad_scaler is not None:
|
||||||
|
self.grad_scaler.step(self.optimizer)
|
||||||
|
self.grad_scaler.update()
|
||||||
|
else:
|
||||||
|
self.optimizer.step()
|
||||||
self.optimizer.zero_grad()
|
self.optimizer.zero_grad()
|
||||||
ui_pbar.update(1)
|
ui_pbar.update(1)
|
||||||
torch.cuda.empty_cache()
|
torch.cuda.empty_cache()
|
||||||
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"training_dtype",
|
"training_dtype",
|
||||||
options=["bf16", "fp32"],
|
options=["bf16", "fp32", "none"],
|
||||||
default="bf16",
|
default="bf16",
|
||||||
tooltip="The dtype to use for training.",
|
tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"lora_dtype",
|
"lora_dtype",
|
||||||
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
io.Boolean.Input(
|
io.Boolean.Input(
|
||||||
"offloading",
|
"offloading",
|
||||||
default=False,
|
default=False,
|
||||||
tooltip="Offload the Model to RAM. Requires Bypass Mode.",
|
tooltip="Offload model weights to CPU during training to save GPU memory.",
|
||||||
),
|
),
|
||||||
io.Combo.Input(
|
io.Combo.Input(
|
||||||
"existing_lora",
|
"existing_lora",
|
||||||
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
|
|
||||||
# Setup model and dtype
|
# Setup model and dtype
|
||||||
mp = model.clone()
|
mp = model.clone()
|
||||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
use_grad_scaler = False
|
||||||
|
if training_dtype != "none":
|
||||||
|
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||||
|
mp.set_model_compute_dtype(dtype)
|
||||||
|
else:
|
||||||
|
# Detect model's native dtype for autocast
|
||||||
|
model_dtype = mp.model.get_dtype()
|
||||||
|
if model_dtype == torch.float16:
|
||||||
|
dtype = torch.float16
|
||||||
|
use_grad_scaler = True
|
||||||
|
# Warn about fp16 accumulation instability during training
|
||||||
|
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||||
|
logging.warning(
|
||||||
|
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||||
|
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||||
|
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||||
|
dtype = torch.bfloat16
|
||||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||||
mp.set_model_compute_dtype(dtype)
|
|
||||||
|
|
||||||
if mp.is_dynamic():
|
|
||||||
if not bypass_mode:
|
|
||||||
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
|
|
||||||
bypass_mode = True
|
|
||||||
offloading = True
|
|
||||||
elif offloading:
|
|
||||||
if not bypass_mode:
|
|
||||||
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
|
|
||||||
|
|
||||||
# Prepare latents and compute counts
|
# Prepare latents and compute counts
|
||||||
|
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||||
latents, dtype, bucket_mode
|
latents, latents_dtype, bucket_mode
|
||||||
)
|
)
|
||||||
|
|
||||||
# Validate and expand conditioning
|
# Validate and expand conditioning
|
||||||
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
bucket_latents=latents,
|
bucket_latents=latents,
|
||||||
|
use_grad_scaler=use_grad_scaler,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
train_sampler = TrainSampler(
|
train_sampler = TrainSampler(
|
||||||
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
|
|||||||
seed=seed,
|
seed=seed,
|
||||||
training_dtype=dtype,
|
training_dtype=dtype,
|
||||||
real_dataset=latents if multi_res else None,
|
real_dataset=latents if multi_res else None,
|
||||||
|
use_grad_scaler=use_grad_scaler,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Setup guider
|
# Setup guider
|
||||||
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
|
|||||||
io.Int.Input(
|
io.Int.Input(
|
||||||
"steps",
|
"steps",
|
||||||
optional=True,
|
optional=True,
|
||||||
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.",
|
tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
|
||||||
),
|
),
|
||||||
],
|
],
|
||||||
outputs=[],
|
outputs=[],
|
||||||
|
|||||||
@ -1,3 +1,3 @@
|
|||||||
# This file is automatically generated by the build process when version is
|
# This file is automatically generated by the build process when version is
|
||||||
# updated in pyproject.toml.
|
# updated in pyproject.toml.
|
||||||
__version__ = "0.16.4"
|
__version__ = "0.18.0"
|
||||||
|
|||||||
147
execution.py
147
execution.py
@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
|
|||||||
from comfy_execution.utils import CurrentNodeContext
|
from comfy_execution.utils import CurrentNodeContext
|
||||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||||
from comfy_api.latest import io, _io
|
from comfy_api.latest import io, _io
|
||||||
|
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||||
|
|
||||||
|
|
||||||
class ExecutionResult(Enum):
|
class ExecutionResult(Enum):
|
||||||
@ -126,15 +127,15 @@ class CacheSet:
|
|||||||
|
|
||||||
# Performs like the old cache -- dump data ASAP
|
# Performs like the old cache -- dump data ASAP
|
||||||
def init_classic_cache(self):
|
def init_classic_cache(self):
|
||||||
self.outputs = HierarchicalCache(CacheKeySetInputSignature)
|
self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
def init_lru_cache(self, cache_size):
|
def init_lru_cache(self, cache_size):
|
||||||
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size)
|
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
def init_ram_cache(self, min_headroom):
|
def init_ram_cache(self, min_headroom):
|
||||||
self.outputs = RAMPressureCache(CacheKeySetInputSignature)
|
self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
|
||||||
self.objects = HierarchicalCache(CacheKeySetID)
|
self.objects = HierarchicalCache(CacheKeySetID)
|
||||||
|
|
||||||
def init_null_cache(self):
|
def init_null_cache(self):
|
||||||
@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
inputs = dynprompt.get_node(unique_id)['inputs']
|
inputs = dynprompt.get_node(unique_id)['inputs']
|
||||||
class_type = dynprompt.get_node(unique_id)['class_type']
|
class_type = dynprompt.get_node(unique_id)['class_type']
|
||||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||||
cached = caches.outputs.get(unique_id)
|
cached = await caches.outputs.get(unique_id)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
if server.client_id is not None:
|
if server.client_id is not None:
|
||||||
cached_ui = cached.ui or {}
|
cached_ui = cached.ui or {}
|
||||||
@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
server.last_node_id = display_node_id
|
server.last_node_id = display_node_id
|
||||||
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
|
||||||
|
|
||||||
obj = caches.objects.get(unique_id)
|
obj = await caches.objects.get(unique_id)
|
||||||
if obj is None:
|
if obj is None:
|
||||||
obj = class_def()
|
obj = class_def()
|
||||||
caches.objects.set(unique_id, obj)
|
await caches.objects.set(unique_id, obj)
|
||||||
|
|
||||||
if issubclass(class_def, _ComfyNodeInternal):
|
if issubclass(class_def, _ComfyNodeInternal):
|
||||||
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
|
||||||
@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
|||||||
|
|
||||||
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
|
||||||
execution_list.cache_update(unique_id, cache_entry)
|
execution_list.cache_update(unique_id, cache_entry)
|
||||||
caches.outputs.set(unique_id, cache_entry)
|
await caches.outputs.set(unique_id, cache_entry)
|
||||||
|
|
||||||
except comfy.model_management.InterruptProcessingException as iex:
|
except comfy.model_management.InterruptProcessingException as iex:
|
||||||
logging.info("Processing interrupted")
|
logging.info("Processing interrupted")
|
||||||
@ -684,6 +685,19 @@ class PromptExecutor:
|
|||||||
}
|
}
|
||||||
self.add_message("execution_error", mes, broadcast=False)
|
self.add_message("execution_error", mes, broadcast=False)
|
||||||
|
|
||||||
|
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
|
||||||
|
if not _has_cache_providers():
|
||||||
|
return
|
||||||
|
|
||||||
|
for provider in _get_cache_providers():
|
||||||
|
try:
|
||||||
|
if event == "start":
|
||||||
|
provider.on_prompt_start(prompt_id)
|
||||||
|
elif event == "end":
|
||||||
|
provider.on_prompt_end(prompt_id)
|
||||||
|
except Exception as e:
|
||||||
|
_cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
|
||||||
|
|
||||||
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
|
||||||
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
|
||||||
|
|
||||||
@ -700,66 +714,75 @@ class PromptExecutor:
|
|||||||
self.status_messages = []
|
self.status_messages = []
|
||||||
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
|
||||||
|
|
||||||
with torch.inference_mode():
|
self._notify_prompt_lifecycle("start", prompt_id)
|
||||||
dynamic_prompt = DynamicPrompt(prompt)
|
|
||||||
reset_progress_state(prompt_id, dynamic_prompt)
|
|
||||||
add_progress_handler(WebUIProgressHandler(self.server))
|
|
||||||
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
|
||||||
for cache in self.caches.all:
|
|
||||||
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
|
||||||
cache.clean_unused()
|
|
||||||
|
|
||||||
cached_nodes = []
|
try:
|
||||||
for node_id in prompt:
|
with torch.inference_mode():
|
||||||
if self.caches.outputs.get(node_id) is not None:
|
dynamic_prompt = DynamicPrompt(prompt)
|
||||||
cached_nodes.append(node_id)
|
reset_progress_state(prompt_id, dynamic_prompt)
|
||||||
|
add_progress_handler(WebUIProgressHandler(self.server))
|
||||||
|
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
|
||||||
|
for cache in self.caches.all:
|
||||||
|
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
|
||||||
|
cache.clean_unused()
|
||||||
|
|
||||||
comfy.model_management.cleanup_models_gc()
|
node_ids = list(prompt.keys())
|
||||||
self.add_message("execution_cached",
|
cache_results = await asyncio.gather(
|
||||||
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
*(self.caches.outputs.get(node_id) for node_id in node_ids)
|
||||||
broadcast=False)
|
)
|
||||||
pending_subgraph_results = {}
|
cached_nodes = [
|
||||||
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
node_id for node_id, result in zip(node_ids, cache_results)
|
||||||
ui_node_outputs = {}
|
if result is not None
|
||||||
executed = set()
|
]
|
||||||
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
|
||||||
current_outputs = self.caches.outputs.all_node_ids()
|
|
||||||
for node_id in list(execute_outputs):
|
|
||||||
execution_list.add_node(node_id)
|
|
||||||
|
|
||||||
while not execution_list.is_empty():
|
comfy.model_management.cleanup_models_gc()
|
||||||
node_id, error, ex = await execution_list.stage_node_execution()
|
self.add_message("execution_cached",
|
||||||
if error is not None:
|
{ "nodes": cached_nodes, "prompt_id": prompt_id},
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
broadcast=False)
|
||||||
break
|
pending_subgraph_results = {}
|
||||||
|
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
|
||||||
|
ui_node_outputs = {}
|
||||||
|
executed = set()
|
||||||
|
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
|
||||||
|
current_outputs = self.caches.outputs.all_node_ids()
|
||||||
|
for node_id in list(execute_outputs):
|
||||||
|
execution_list.add_node(node_id)
|
||||||
|
|
||||||
assert node_id is not None, "Node ID should not be None at this point"
|
while not execution_list.is_empty():
|
||||||
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
node_id, error, ex = await execution_list.stage_node_execution()
|
||||||
self.success = result != ExecutionResult.FAILURE
|
if error is not None:
|
||||||
if result == ExecutionResult.FAILURE:
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
break
|
||||||
break
|
|
||||||
elif result == ExecutionResult.PENDING:
|
|
||||||
execution_list.unstage_node_execution()
|
|
||||||
else: # result == ExecutionResult.SUCCESS:
|
|
||||||
execution_list.complete_node_execution()
|
|
||||||
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
|
||||||
else:
|
|
||||||
# Only execute when the while-loop ends without break
|
|
||||||
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
|
||||||
|
|
||||||
ui_outputs = {}
|
assert node_id is not None, "Node ID should not be None at this point"
|
||||||
meta_outputs = {}
|
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
|
||||||
for node_id, ui_info in ui_node_outputs.items():
|
self.success = result != ExecutionResult.FAILURE
|
||||||
ui_outputs[node_id] = ui_info["output"]
|
if result == ExecutionResult.FAILURE:
|
||||||
meta_outputs[node_id] = ui_info["meta"]
|
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
|
||||||
self.history_result = {
|
break
|
||||||
"outputs": ui_outputs,
|
elif result == ExecutionResult.PENDING:
|
||||||
"meta": meta_outputs,
|
execution_list.unstage_node_execution()
|
||||||
}
|
else: # result == ExecutionResult.SUCCESS:
|
||||||
self.server.last_node_id = None
|
execution_list.complete_node_execution()
|
||||||
if comfy.model_management.DISABLE_SMART_MEMORY:
|
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
|
||||||
comfy.model_management.unload_all_models()
|
else:
|
||||||
|
# Only execute when the while-loop ends without break
|
||||||
|
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
|
||||||
|
|
||||||
|
ui_outputs = {}
|
||||||
|
meta_outputs = {}
|
||||||
|
for node_id, ui_info in ui_node_outputs.items():
|
||||||
|
ui_outputs[node_id] = ui_info["output"]
|
||||||
|
meta_outputs[node_id] = ui_info["meta"]
|
||||||
|
self.history_result = {
|
||||||
|
"outputs": ui_outputs,
|
||||||
|
"meta": meta_outputs,
|
||||||
|
}
|
||||||
|
self.server.last_node_id = None
|
||||||
|
if comfy.model_management.DISABLE_SMART_MEMORY:
|
||||||
|
comfy.model_management.unload_all_models()
|
||||||
|
finally:
|
||||||
|
self._notify_prompt_lifecycle("end", prompt_id)
|
||||||
|
|
||||||
|
|
||||||
async def validate_inputs(prompt_id, prompt, item, validated):
|
async def validate_inputs(prompt_id, prompt, item, validated):
|
||||||
|
|||||||
4
main.py
4
main.py
@ -206,8 +206,8 @@ import hook_breaker_ac10a0
|
|||||||
import comfy.memory_management
|
import comfy.memory_management
|
||||||
import comfy.model_patcher
|
import comfy.model_patcher
|
||||||
|
|
||||||
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl():
|
if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
|
||||||
if comfy.model_management.torch_version_numeric < (2, 8):
|
if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
|
||||||
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
|
||||||
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
|
||||||
if args.verbose == 'DEBUG':
|
if args.verbose == 'DEBUG':
|
||||||
|
|||||||
@ -1 +1 @@
|
|||||||
comfyui_manager==4.1b2
|
comfyui_manager==4.1b6
|
||||||
@ -32,7 +32,7 @@ async def cache_control(
|
|||||||
)
|
)
|
||||||
|
|
||||||
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
|
||||||
response.headers.setdefault("Cache-Control", "no-cache")
|
response.headers.setdefault("Cache-Control", "no-store")
|
||||||
return response
|
return response
|
||||||
|
|
||||||
# Early return for non-image files - no cache headers needed
|
# Early return for non-image files - no cache headers needed
|
||||||
|
|||||||
24
nodes.py
24
nodes.py
@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
|
|||||||
|
|
||||||
|
|
||||||
class ConditioningCombine:
|
class ConditioningCombine:
|
||||||
|
ESSENTIALS_CATEGORY = "Image Generation"
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
|
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
|
||||||
@ -951,7 +952,7 @@ class UNETLoader:
|
|||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
|
||||||
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],)
|
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
|
||||||
}}
|
}}
|
||||||
RETURN_TYPES = ("MODEL",)
|
RETURN_TYPES = ("MODEL",)
|
||||||
FUNCTION = "load_unet"
|
FUNCTION = "load_unet"
|
||||||
@ -1211,9 +1212,6 @@ class GLIGENTextBoxApply:
|
|||||||
return (c, )
|
return (c, )
|
||||||
|
|
||||||
class EmptyLatentImage:
|
class EmptyLatentImage:
|
||||||
def __init__(self):
|
|
||||||
self.device = comfy.model_management.intermediate_device()
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def INPUT_TYPES(s):
|
def INPUT_TYPES(s):
|
||||||
return {
|
return {
|
||||||
@ -1232,7 +1230,7 @@ class EmptyLatentImage:
|
|||||||
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
|
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1):
|
def generate(self, width, height, batch_size=1):
|
||||||
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device)
|
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||||
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
|
return ({"samples": latent, "downscale_ratio_spacial": 8}, )
|
||||||
|
|
||||||
|
|
||||||
@ -1724,6 +1722,8 @@ class LoadImage:
|
|||||||
output_masks = []
|
output_masks = []
|
||||||
w, h = None, None
|
w, h = None, None
|
||||||
|
|
||||||
|
dtype = comfy.model_management.intermediate_dtype()
|
||||||
|
|
||||||
for i in ImageSequence.Iterator(img):
|
for i in ImageSequence.Iterator(img):
|
||||||
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
i = node_helpers.pillow(ImageOps.exif_transpose, i)
|
||||||
|
|
||||||
@ -1748,8 +1748,8 @@ class LoadImage:
|
|||||||
mask = 1. - torch.from_numpy(mask)
|
mask = 1. - torch.from_numpy(mask)
|
||||||
else:
|
else:
|
||||||
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
|
||||||
output_images.append(image)
|
output_images.append(image.to(dtype=dtype))
|
||||||
output_masks.append(mask.unsqueeze(0))
|
output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
|
||||||
|
|
||||||
if img.format == "MPO":
|
if img.format == "MPO":
|
||||||
break # ignore all frames except the first one for MPO format
|
break # ignore all frames except the first one for MPO format
|
||||||
@ -1779,6 +1779,7 @@ class LoadImage:
|
|||||||
return True
|
return True
|
||||||
|
|
||||||
class LoadImageMask:
|
class LoadImageMask:
|
||||||
|
ESSENTIALS_CATEGORY = "Image Tools"
|
||||||
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
|
||||||
|
|
||||||
_color_channels = ["alpha", "red", "green", "blue"]
|
_color_channels = ["alpha", "red", "green", "blue"]
|
||||||
@ -1887,6 +1888,7 @@ class ImageScale:
|
|||||||
return (s,)
|
return (s,)
|
||||||
|
|
||||||
class ImageScaleBy:
|
class ImageScaleBy:
|
||||||
|
ESSENTIALS_CATEGORY = "Image Tools"
|
||||||
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
@ -1964,9 +1966,11 @@ class EmptyImage:
|
|||||||
CATEGORY = "image"
|
CATEGORY = "image"
|
||||||
|
|
||||||
def generate(self, width, height, batch_size=1, color=0):
|
def generate(self, width, height, batch_size=1, color=0):
|
||||||
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF)
|
dtype = comfy.model_management.intermediate_dtype()
|
||||||
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF)
|
device = comfy.model_management.intermediate_device()
|
||||||
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF)
|
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype)
|
||||||
|
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype)
|
||||||
|
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype)
|
||||||
return (torch.cat((r, g, b), dim=-1), )
|
return (torch.cat((r, g, b), dim=-1), )
|
||||||
|
|
||||||
class ImagePadForOutpaint:
|
class ImagePadForOutpaint:
|
||||||
|
|||||||
@ -1,6 +1,6 @@
|
|||||||
[project]
|
[project]
|
||||||
name = "ComfyUI"
|
name = "ComfyUI"
|
||||||
version = "0.16.4"
|
version = "0.18.0"
|
||||||
readme = "README.md"
|
readme = "README.md"
|
||||||
license = { file = "LICENSE" }
|
license = { file = "LICENSE" }
|
||||||
requires-python = ">=3.10"
|
requires-python = ">=3.10"
|
||||||
|
|||||||
@ -1,5 +1,5 @@
|
|||||||
comfyui-frontend-package==1.41.18
|
comfyui-frontend-package==1.41.21
|
||||||
comfyui-workflow-templates==0.9.21
|
comfyui-workflow-templates==0.9.26
|
||||||
comfyui-embedded-docs==0.4.3
|
comfyui-embedded-docs==0.4.3
|
||||||
torch
|
torch
|
||||||
torchsde
|
torchsde
|
||||||
@ -23,7 +23,7 @@ SQLAlchemy
|
|||||||
filelock
|
filelock
|
||||||
av>=14.2.0
|
av>=14.2.0
|
||||||
comfy-kitchen>=0.2.8
|
comfy-kitchen>=0.2.8
|
||||||
comfy-aimdo>=0.2.10
|
comfy-aimdo>=0.2.12
|
||||||
requests
|
requests
|
||||||
simpleeval>=1.0.0
|
simpleeval>=1.0.0
|
||||||
blake3
|
blake3
|
||||||
|
|||||||
81
server.py
81
server.py
@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
|
|||||||
from comfy_api.internal import _ComfyNodeInternal
|
from comfy_api.internal import _ComfyNodeInternal
|
||||||
from app.assets.seeder import asset_seeder
|
from app.assets.seeder import asset_seeder
|
||||||
from app.assets.api.routes import register_assets_routes
|
from app.assets.api.routes import register_assets_routes
|
||||||
|
from app.assets.services.ingest import register_file_in_place
|
||||||
|
from app.assets.services.asset_management import resolve_hash_to_path
|
||||||
|
|
||||||
from app.user_manager import UserManager
|
from app.user_manager import UserManager
|
||||||
from app.model_manager import ModelFileManager
|
from app.model_manager import ModelFileManager
|
||||||
@ -310,7 +312,7 @@ class PromptServer():
|
|||||||
@routes.get("/")
|
@routes.get("/")
|
||||||
async def get_root(request):
|
async def get_root(request):
|
||||||
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
response = web.FileResponse(os.path.join(self.web_root, "index.html"))
|
||||||
response.headers['Cache-Control'] = 'no-cache'
|
response.headers['Cache-Control'] = 'no-store, must-revalidate'
|
||||||
response.headers["Pragma"] = "no-cache"
|
response.headers["Pragma"] = "no-cache"
|
||||||
response.headers["Expires"] = "0"
|
response.headers["Expires"] = "0"
|
||||||
return response
|
return response
|
||||||
@ -419,7 +421,24 @@ class PromptServer():
|
|||||||
with open(filepath, "wb") as f:
|
with open(filepath, "wb") as f:
|
||||||
f.write(image.file.read())
|
f.write(image.file.read())
|
||||||
|
|
||||||
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type})
|
resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
|
||||||
|
|
||||||
|
if args.enable_assets:
|
||||||
|
try:
|
||||||
|
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
|
||||||
|
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
|
||||||
|
resp["asset"] = {
|
||||||
|
"id": result.ref.id,
|
||||||
|
"name": result.ref.name,
|
||||||
|
"asset_hash": result.asset.hash,
|
||||||
|
"size": result.asset.size_bytes,
|
||||||
|
"mime_type": result.asset.mime_type,
|
||||||
|
"tags": result.tags,
|
||||||
|
}
|
||||||
|
except Exception:
|
||||||
|
logging.warning("Failed to register uploaded image as asset", exc_info=True)
|
||||||
|
|
||||||
|
return web.json_response(resp)
|
||||||
else:
|
else:
|
||||||
return web.Response(status=400)
|
return web.Response(status=400)
|
||||||
|
|
||||||
@ -479,30 +498,43 @@ class PromptServer():
|
|||||||
async def view_image(request):
|
async def view_image(request):
|
||||||
if "filename" in request.rel_url.query:
|
if "filename" in request.rel_url.query:
|
||||||
filename = request.rel_url.query["filename"]
|
filename = request.rel_url.query["filename"]
|
||||||
filename, output_dir = folder_paths.annotated_filepath(filename)
|
|
||||||
|
|
||||||
if not filename:
|
# The frontend's LoadImage combo widget uses asset_hash values
|
||||||
return web.Response(status=400)
|
# (e.g. "blake3:...") as widget values. When litegraph renders the
|
||||||
|
# node preview, it constructs /view?filename=<asset_hash>, so this
|
||||||
|
# endpoint must resolve blake3 hashes to their on-disk file paths.
|
||||||
|
if filename.startswith("blake3:"):
|
||||||
|
owner_id = self.user_manager.get_request_user_id(request)
|
||||||
|
result = resolve_hash_to_path(filename, owner_id=owner_id)
|
||||||
|
if result is None:
|
||||||
|
return web.Response(status=404)
|
||||||
|
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
|
||||||
|
else:
|
||||||
|
resolved_content_type = None
|
||||||
|
filename, output_dir = folder_paths.annotated_filepath(filename)
|
||||||
|
|
||||||
# validation for security: prevent accessing arbitrary path
|
if not filename:
|
||||||
if filename[0] == '/' or '..' in filename:
|
return web.Response(status=400)
|
||||||
return web.Response(status=400)
|
|
||||||
|
|
||||||
if output_dir is None:
|
# validation for security: prevent accessing arbitrary path
|
||||||
type = request.rel_url.query.get("type", "output")
|
if filename[0] == '/' or '..' in filename:
|
||||||
output_dir = folder_paths.get_directory_by_type(type)
|
return web.Response(status=400)
|
||||||
|
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
return web.Response(status=400)
|
type = request.rel_url.query.get("type", "output")
|
||||||
|
output_dir = folder_paths.get_directory_by_type(type)
|
||||||
|
|
||||||
if "subfolder" in request.rel_url.query:
|
if output_dir is None:
|
||||||
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
return web.Response(status=400)
|
||||||
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
|
||||||
return web.Response(status=403)
|
|
||||||
output_dir = full_output_dir
|
|
||||||
|
|
||||||
filename = os.path.basename(filename)
|
if "subfolder" in request.rel_url.query:
|
||||||
file = os.path.join(output_dir, filename)
|
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
|
||||||
|
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
|
||||||
|
return web.Response(status=403)
|
||||||
|
output_dir = full_output_dir
|
||||||
|
|
||||||
|
filename = os.path.basename(filename)
|
||||||
|
file = os.path.join(output_dir, filename)
|
||||||
|
|
||||||
if os.path.isfile(file):
|
if os.path.isfile(file):
|
||||||
if 'preview' in request.rel_url.query:
|
if 'preview' in request.rel_url.query:
|
||||||
@ -562,8 +594,13 @@ class PromptServer():
|
|||||||
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
return web.Response(body=alpha_buffer.read(), content_type='image/png',
|
||||||
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
headers={"Content-Disposition": f"filename=\"{filename}\""})
|
||||||
else:
|
else:
|
||||||
# Get content type from mimetype, defaulting to 'application/octet-stream'
|
# Use the content type from asset resolution if available,
|
||||||
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream'
|
# otherwise guess from the filename.
|
||||||
|
content_type = (
|
||||||
|
resolved_content_type
|
||||||
|
or mimetypes.guess_type(filename)[0]
|
||||||
|
or 'application/octet-stream'
|
||||||
|
)
|
||||||
|
|
||||||
# For security, force certain mimetypes to download instead of display
|
# For security, force certain mimetypes to download instead of display
|
||||||
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:
|
||||||
|
|||||||
57
tests-unit/app_test/test_migrations.py
Normal file
57
tests-unit/app_test/test_migrations.py
Normal file
@ -0,0 +1,57 @@
|
|||||||
|
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
|
||||||
|
|
||||||
|
This catches problems like unnamed FK constraints that prevent batch-mode
|
||||||
|
drop_constraint from working on real SQLite files (see MB-2).
|
||||||
|
|
||||||
|
Migrations 0001 and 0002 are already shipped, so we only exercise
|
||||||
|
upgrade/downgrade for 0003+.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from alembic import command
|
||||||
|
from alembic.config import Config
|
||||||
|
|
||||||
|
|
||||||
|
# Oldest shipped revision — we upgrade to here as a baseline and never
|
||||||
|
# downgrade past it.
|
||||||
|
_BASELINE = "0002_merge_to_asset_references"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_config(db_path: str) -> Config:
|
||||||
|
root = os.path.join(os.path.dirname(__file__), "../..")
|
||||||
|
config_path = os.path.abspath(os.path.join(root, "alembic.ini"))
|
||||||
|
scripts_path = os.path.abspath(os.path.join(root, "alembic_db"))
|
||||||
|
|
||||||
|
cfg = Config(config_path)
|
||||||
|
cfg.set_main_option("script_location", scripts_path)
|
||||||
|
cfg.set_main_option("sqlalchemy.url", f"sqlite:///{db_path}")
|
||||||
|
return cfg
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
|
||||||
|
def migration_db(tmp_path):
|
||||||
|
"""Yield an alembic Config pre-upgraded to the baseline revision."""
|
||||||
|
db_path = str(tmp_path / "test_migration.db")
|
||||||
|
cfg = _make_config(db_path)
|
||||||
|
command.upgrade(cfg, _BASELINE)
|
||||||
|
yield cfg
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade_to_head(migration_db):
|
||||||
|
"""Upgrade from baseline to head must succeed on a file-backed DB."""
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
|
|
||||||
|
|
||||||
|
def test_downgrade_to_baseline(migration_db):
|
||||||
|
"""Upgrade to head then downgrade back to baseline."""
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
|
command.downgrade(migration_db, _BASELINE)
|
||||||
|
|
||||||
|
|
||||||
|
def test_upgrade_downgrade_cycle(migration_db):
|
||||||
|
"""Full cycle: upgrade → downgrade → upgrade again."""
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
|
command.downgrade(migration_db, _BASELINE)
|
||||||
|
command.upgrade(migration_db, "head")
|
||||||
@ -10,6 +10,7 @@ from app.assets.database.queries import (
|
|||||||
get_asset_by_hash,
|
get_asset_by_hash,
|
||||||
upsert_asset,
|
upsert_asset,
|
||||||
bulk_insert_assets,
|
bulk_insert_assets,
|
||||||
|
update_asset_hash_and_mime,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@ -142,3 +143,45 @@ class TestBulkInsertAssets:
|
|||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
assert session.query(Asset).count() == 200
|
assert session.query(Asset).count() == 200
|
||||||
|
|
||||||
|
|
||||||
|
class TestMimeTypeImmutability:
|
||||||
|
"""mime_type on Asset is write-once: set on first ingest, never overwritten."""
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"initial_mime,second_mime,expected_mime",
|
||||||
|
[
|
||||||
|
("image/png", "image/jpeg", "image/png"),
|
||||||
|
(None, "image/png", "image/png"),
|
||||||
|
],
|
||||||
|
ids=["preserves_existing", "fills_null"],
|
||||||
|
)
|
||||||
|
def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
|
||||||
|
h = f"blake3:upsert_{initial_mime}_{second_mime}"
|
||||||
|
upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=initial_mime)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
asset, created, _ = upsert_asset(session, asset_hash=h, size_bytes=100, mime_type=second_mime)
|
||||||
|
assert created is False
|
||||||
|
assert asset.mime_type == expected_mime
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"initial_mime,update_mime,update_hash,expected_mime,expected_hash",
|
||||||
|
[
|
||||||
|
(None, "image/png", None, "image/png", "blake3:upd0"),
|
||||||
|
("image/png", "image/jpeg", None, "image/png", "blake3:upd1"),
|
||||||
|
("image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new"),
|
||||||
|
],
|
||||||
|
ids=["fills_null", "preserves_existing", "hash_updates_mime_locked"],
|
||||||
|
)
|
||||||
|
def test_update_asset_hash_and_mime_immutability(
|
||||||
|
self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
|
||||||
|
):
|
||||||
|
h = expected_hash.removesuffix("_new")
|
||||||
|
asset = Asset(hash=h, size_bytes=100, mime_type=initial_mime)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
|
||||||
|
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=update_mime, asset_hash=update_hash)
|
||||||
|
assert asset.mime_type == expected_mime
|
||||||
|
assert asset.hash == expected_hash
|
||||||
|
|||||||
@ -242,22 +242,24 @@ class TestSetReferencePreview:
|
|||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
preview_asset = _make_asset(session, "preview_hash")
|
preview_asset = _make_asset(session, "preview_hash")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id)
|
set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
session.refresh(ref)
|
session.refresh(ref)
|
||||||
assert ref.preview_id == preview_asset.id
|
assert ref.preview_id == preview_ref.id
|
||||||
|
|
||||||
def test_clears_preview(self, session: Session):
|
def test_clears_preview(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
preview_asset = _make_asset(session, "preview_hash")
|
preview_asset = _make_asset(session, "preview_hash")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
ref.preview_id = preview_asset.id
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
|
ref.preview_id = preview_ref.id
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None)
|
set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
session.refresh(ref)
|
session.refresh(ref)
|
||||||
@ -265,15 +267,15 @@ class TestSetReferencePreview:
|
|||||||
|
|
||||||
def test_raises_for_nonexistent_reference(self, session: Session):
|
def test_raises_for_nonexistent_reference(self, session: Session):
|
||||||
with pytest.raises(ValueError, match="not found"):
|
with pytest.raises(ValueError, match="not found"):
|
||||||
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None)
|
set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
|
||||||
|
|
||||||
def test_raises_for_nonexistent_preview(self, session: Session):
|
def test_raises_for_nonexistent_preview(self, session: Session):
|
||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="Preview Asset"):
|
with pytest.raises(ValueError, match="Preview AssetReference"):
|
||||||
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent")
|
set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
|
||||||
|
|
||||||
|
|
||||||
class TestInsertReference:
|
class TestInsertReference:
|
||||||
@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
|
|||||||
asset = _make_asset(session, "hash1")
|
asset = _make_asset(session, "hash1")
|
||||||
preview_asset = _make_asset(session, "preview_hash")
|
preview_asset = _make_asset(session, "preview_hash")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
update_reference_timestamps(session, ref, preview_id=preview_asset.id)
|
update_reference_timestamps(session, ref, preview_id=preview_ref.id)
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
session.refresh(ref)
|
session.refresh(ref)
|
||||||
assert ref.preview_id == preview_asset.id
|
assert ref.preview_id == preview_ref.id
|
||||||
|
|
||||||
|
|
||||||
class TestSetReferenceMetadata:
|
class TestSetReferenceMetadata:
|
||||||
|
|||||||
@ -20,6 +20,7 @@ def _make_reference(
|
|||||||
asset: Asset,
|
asset: Asset,
|
||||||
name: str,
|
name: str,
|
||||||
metadata: dict | None = None,
|
metadata: dict | None = None,
|
||||||
|
system_metadata: dict | None = None,
|
||||||
) -> AssetReference:
|
) -> AssetReference:
|
||||||
now = get_utc_now()
|
now = get_utc_now()
|
||||||
ref = AssetReference(
|
ref = AssetReference(
|
||||||
@ -27,6 +28,7 @@ def _make_reference(
|
|||||||
name=name,
|
name=name,
|
||||||
asset_id=asset.id,
|
asset_id=asset.id,
|
||||||
user_metadata=metadata,
|
user_metadata=metadata,
|
||||||
|
system_metadata=system_metadata,
|
||||||
created_at=now,
|
created_at=now,
|
||||||
updated_at=now,
|
updated_at=now,
|
||||||
last_access_time=now,
|
last_access_time=now,
|
||||||
@ -34,8 +36,10 @@ def _make_reference(
|
|||||||
session.add(ref)
|
session.add(ref)
|
||||||
session.flush()
|
session.flush()
|
||||||
|
|
||||||
if metadata:
|
# Build merged projection: {**system_metadata, **user_metadata}
|
||||||
for key, val in metadata.items():
|
merged = {**(system_metadata or {}), **(metadata or {})}
|
||||||
|
if merged:
|
||||||
|
for key, val in merged.items():
|
||||||
for row in convert_metadata_to_rows(key, val):
|
for row in convert_metadata_to_rows(key, val):
|
||||||
meta_row = AssetReferenceMeta(
|
meta_row = AssetReferenceMeta(
|
||||||
asset_reference_id=ref.id,
|
asset_reference_id=ref.id,
|
||||||
@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
|
|||||||
|
|
||||||
refs, _, total = list_references_page(session, metadata_filter={})
|
refs, _, total = list_references_page(session, metadata_filter={})
|
||||||
assert total == 2
|
assert total == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestSystemMetadataProjection:
|
||||||
|
"""Tests for system_metadata merging into the filter projection."""
|
||||||
|
|
||||||
|
def test_system_metadata_keys_are_filterable(self, session: Session):
|
||||||
|
"""system_metadata keys should appear in the merged projection."""
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
_make_reference(
|
||||||
|
session, asset, "with_sys",
|
||||||
|
system_metadata={"source": "scanner"},
|
||||||
|
)
|
||||||
|
_make_reference(session, asset, "without_sys")
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, metadata_filter={"source": "scanner"}
|
||||||
|
)
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].name == "with_sys"
|
||||||
|
|
||||||
|
def test_user_metadata_overrides_system_metadata(self, session: Session):
|
||||||
|
"""user_metadata should win when both have the same key."""
|
||||||
|
asset = _make_asset(session, "hash1")
|
||||||
|
_make_reference(
|
||||||
|
session, asset, "overridden",
|
||||||
|
metadata={"origin": "user_upload"},
|
||||||
|
system_metadata={"origin": "auto_scan"},
|
||||||
|
)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
# Should match the user value, not the system value
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, metadata_filter={"origin": "user_upload"}
|
||||||
|
)
|
||||||
|
assert total == 1
|
||||||
|
assert refs[0].name == "overridden"
|
||||||
|
|
||||||
|
# Should NOT match the system value (it was overridden)
|
||||||
|
refs, _, total = list_references_page(
|
||||||
|
session, metadata_filter={"origin": "auto_scan"}
|
||||||
|
)
|
||||||
|
assert total == 0
|
||||||
|
|||||||
@ -11,6 +11,7 @@ from app.assets.services import (
|
|||||||
delete_asset_reference,
|
delete_asset_reference,
|
||||||
set_asset_preview,
|
set_asset_preview,
|
||||||
)
|
)
|
||||||
|
from app.assets.services.asset_management import resolve_hash_to_path
|
||||||
|
|
||||||
|
|
||||||
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
|
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
|
||||||
@ -219,31 +220,33 @@ class TestSetAssetPreview:
|
|||||||
asset = _make_asset(session, hash_val="blake3:main")
|
asset = _make_asset(session, hash_val="blake3:main")
|
||||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
ref_id = ref.id
|
ref_id = ref.id
|
||||||
preview_id = preview_asset.id
|
preview_ref_id = preview_ref.id
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_asset_preview(
|
set_asset_preview(
|
||||||
reference_id=ref_id,
|
reference_id=ref_id,
|
||||||
preview_asset_id=preview_id,
|
preview_reference_id=preview_ref_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify by re-fetching from DB
|
# Verify by re-fetching from DB
|
||||||
session.expire_all()
|
session.expire_all()
|
||||||
updated_ref = session.get(AssetReference, ref_id)
|
updated_ref = session.get(AssetReference, ref_id)
|
||||||
assert updated_ref.preview_id == preview_id
|
assert updated_ref.preview_id == preview_ref_id
|
||||||
|
|
||||||
def test_clears_preview(self, mock_create_session, session: Session):
|
def test_clears_preview(self, mock_create_session, session: Session):
|
||||||
asset = _make_asset(session)
|
asset = _make_asset(session)
|
||||||
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
preview_asset = _make_asset(session, hash_val="blake3:preview")
|
||||||
ref = _make_reference(session, asset)
|
ref = _make_reference(session, asset)
|
||||||
ref.preview_id = preview_asset.id
|
preview_ref = _make_reference(session, preview_asset, name="preview.png")
|
||||||
|
ref.preview_id = preview_ref.id
|
||||||
ref_id = ref.id
|
ref_id = ref.id
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
set_asset_preview(
|
set_asset_preview(
|
||||||
reference_id=ref_id,
|
reference_id=ref_id,
|
||||||
preview_asset_id=None,
|
preview_reference_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verify by re-fetching from DB
|
# Verify by re-fetching from DB
|
||||||
@ -263,6 +266,45 @@ class TestSetAssetPreview:
|
|||||||
with pytest.raises(PermissionError, match="not owner"):
|
with pytest.raises(PermissionError, match="not owner"):
|
||||||
set_asset_preview(
|
set_asset_preview(
|
||||||
reference_id=ref.id,
|
reference_id=ref.id,
|
||||||
preview_asset_id=None,
|
preview_reference_id=None,
|
||||||
owner_id="user2",
|
owner_id="user2",
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestResolveHashToPath:
|
||||||
|
def test_returns_none_for_unknown_hash(self, mock_create_session):
|
||||||
|
result = resolve_hash_to_path("blake3:" + "a" * 64)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"ref_owner, query_owner, expect_found",
|
||||||
|
[
|
||||||
|
("user1", "user1", True),
|
||||||
|
("user1", "user2", False),
|
||||||
|
("", "anyone", True),
|
||||||
|
("", "", True),
|
||||||
|
],
|
||||||
|
ids=[
|
||||||
|
"owner_sees_own_ref",
|
||||||
|
"other_owner_blocked",
|
||||||
|
"ownerless_visible_to_anyone",
|
||||||
|
"ownerless_visible_to_empty",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
def test_owner_visibility(
|
||||||
|
self, ref_owner, query_owner, expect_found,
|
||||||
|
mock_create_session, session: Session, temp_dir,
|
||||||
|
):
|
||||||
|
f = temp_dir / "file.bin"
|
||||||
|
f.write_bytes(b"data")
|
||||||
|
asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
|
||||||
|
ref = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
|
||||||
|
ref.file_path = str(f)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = resolve_hash_to_path(asset.hash, owner_id=query_owner)
|
||||||
|
if expect_found:
|
||||||
|
assert result is not None
|
||||||
|
assert result.abs_path == str(f)
|
||||||
|
else:
|
||||||
|
assert result is None
|
||||||
|
|||||||
@ -113,11 +113,19 @@ class TestIngestFileFromPath:
|
|||||||
file_path = temp_dir / "with_preview.bin"
|
file_path = temp_dir / "with_preview.bin"
|
||||||
file_path.write_bytes(b"data")
|
file_path.write_bytes(b"data")
|
||||||
|
|
||||||
# Create a preview asset first
|
# Create a preview asset and reference
|
||||||
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
|
preview_asset = Asset(hash="blake3:preview", size_bytes=100)
|
||||||
session.add(preview_asset)
|
session.add(preview_asset)
|
||||||
|
session.flush()
|
||||||
|
from app.assets.helpers import get_utc_now
|
||||||
|
now = get_utc_now()
|
||||||
|
preview_ref = AssetReference(
|
||||||
|
asset_id=preview_asset.id, name="preview.png", owner_id="",
|
||||||
|
created_at=now, updated_at=now, last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(preview_ref)
|
||||||
session.commit()
|
session.commit()
|
||||||
preview_id = preview_asset.id
|
preview_id = preview_ref.id
|
||||||
|
|
||||||
result = _ingest_file_from_path(
|
result = _ingest_file_from_path(
|
||||||
abs_path=str(file_path),
|
abs_path=str(file_path),
|
||||||
|
|||||||
123
tests-unit/assets_test/services/test_tag_histogram.py
Normal file
123
tests-unit/assets_test/services/test_tag_histogram.py
Normal file
@ -0,0 +1,123 @@
|
|||||||
|
"""Tests for list_tag_histogram service function."""
|
||||||
|
from sqlalchemy.orm import Session
|
||||||
|
|
||||||
|
from app.assets.database.models import Asset, AssetReference
|
||||||
|
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
|
||||||
|
from app.assets.helpers import get_utc_now
|
||||||
|
from app.assets.services.tagging import list_tag_histogram
|
||||||
|
|
||||||
|
|
||||||
|
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
|
||||||
|
asset = Asset(hash=hash_val, size_bytes=1024)
|
||||||
|
session.add(asset)
|
||||||
|
session.flush()
|
||||||
|
return asset
|
||||||
|
|
||||||
|
|
||||||
|
def _make_reference(
|
||||||
|
session: Session,
|
||||||
|
asset: Asset,
|
||||||
|
name: str = "test",
|
||||||
|
owner_id: str = "",
|
||||||
|
) -> AssetReference:
|
||||||
|
now = get_utc_now()
|
||||||
|
ref = AssetReference(
|
||||||
|
owner_id=owner_id,
|
||||||
|
name=name,
|
||||||
|
asset_id=asset.id,
|
||||||
|
created_at=now,
|
||||||
|
updated_at=now,
|
||||||
|
last_access_time=now,
|
||||||
|
)
|
||||||
|
session.add(ref)
|
||||||
|
session.flush()
|
||||||
|
return ref
|
||||||
|
|
||||||
|
|
||||||
|
class TestListTagHistogram:
|
||||||
|
def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["alpha", "beta"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha", "beta"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="r2")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["alpha"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram()
|
||||||
|
|
||||||
|
assert result["alpha"] == 2
|
||||||
|
assert result["beta"] == 1
|
||||||
|
|
||||||
|
def test_empty_when_no_assets(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["unused"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram()
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
def test_include_tags_filter(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="r2")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(include_tags=["models"])
|
||||||
|
|
||||||
|
# Only r1 has "models", so only its tags appear
|
||||||
|
assert "models" in result
|
||||||
|
assert "loras" in result
|
||||||
|
assert "input" not in result
|
||||||
|
|
||||||
|
def test_exclude_tags_filter(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["models", "loras", "input"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["models", "loras"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="r2")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["input"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(exclude_tags=["models"])
|
||||||
|
|
||||||
|
# r1 excluded, only r2's tags remain
|
||||||
|
assert "input" in result
|
||||||
|
assert "loras" not in result
|
||||||
|
|
||||||
|
def test_name_contains_filter(self, mock_create_session, session: Session):
|
||||||
|
ensure_tags_exist(session, ["alpha", "beta"])
|
||||||
|
a1 = _make_asset(session, "blake3:aaa")
|
||||||
|
r1 = _make_reference(session, a1, name="my_model.safetensors")
|
||||||
|
add_tags_to_reference(session, reference_id=r1.id, tags=["alpha"])
|
||||||
|
|
||||||
|
a2 = _make_asset(session, "blake3:bbb")
|
||||||
|
r2 = _make_reference(session, a2, name="picture.png")
|
||||||
|
add_tags_to_reference(session, reference_id=r2.id, tags=["beta"])
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(name_contains="model")
|
||||||
|
|
||||||
|
assert "alpha" in result
|
||||||
|
assert "beta" not in result
|
||||||
|
|
||||||
|
def test_limit_caps_results(self, mock_create_session, session: Session):
|
||||||
|
tags = [f"tag{i}" for i in range(10)]
|
||||||
|
ensure_tags_exist(session, tags)
|
||||||
|
a = _make_asset(session, "blake3:aaa")
|
||||||
|
r = _make_reference(session, a, name="r1")
|
||||||
|
add_tags_to_reference(session, reference_id=r.id, tags=tags)
|
||||||
|
session.commit()
|
||||||
|
|
||||||
|
result = list_tag_histogram(limit=3)
|
||||||
|
|
||||||
|
assert len(result) == 3
|
||||||
@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
|
|||||||
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
|
||||||
|
|
||||||
|
|
||||||
|
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
|
||||||
|
files = {"file": ("notags.bin", b"A" * 64, "application/octet-stream")}
|
||||||
|
form = {"tags": json.dumps([]), "name": "notags.bin", "user_metadata": json.dumps({})}
|
||||||
|
r = http.post(api_base + "/api/assets", data=form, files=files, timeout=120)
|
||||||
|
body = r.json()
|
||||||
|
assert r.status_code == 400
|
||||||
|
assert body["error"]["code"] == "INVALID_BODY"
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize("root", ["input", "output"])
|
@pytest.mark.parametrize("root", ["input", "output"])
|
||||||
def test_duplicate_upload_same_display_name_does_not_clobber(
|
def test_duplicate_upload_same_display_name_does_not_clobber(
|
||||||
root: str,
|
root: str,
|
||||||
|
|||||||
403
tests-unit/execution_test/test_cache_provider.py
Normal file
403
tests-unit/execution_test/test_cache_provider.py
Normal file
@ -0,0 +1,403 @@
|
|||||||
|
"""Tests for external cache provider API."""
|
||||||
|
|
||||||
|
import importlib.util
|
||||||
|
import pytest
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
|
||||||
|
def _torch_available() -> bool:
|
||||||
|
"""Check if PyTorch is available."""
|
||||||
|
return importlib.util.find_spec("torch") is not None
|
||||||
|
|
||||||
|
|
||||||
|
from comfy_execution.cache_provider import (
|
||||||
|
CacheProvider,
|
||||||
|
CacheContext,
|
||||||
|
CacheValue,
|
||||||
|
register_cache_provider,
|
||||||
|
unregister_cache_provider,
|
||||||
|
_get_cache_providers,
|
||||||
|
_has_cache_providers,
|
||||||
|
_clear_cache_providers,
|
||||||
|
_serialize_cache_key,
|
||||||
|
_contains_self_unequal,
|
||||||
|
_estimate_value_size,
|
||||||
|
_canonicalize,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class TestCanonicalize:
|
||||||
|
"""Test _canonicalize function for deterministic ordering."""
|
||||||
|
|
||||||
|
def test_frozenset_ordering_is_deterministic(self):
|
||||||
|
"""Frozensets should produce consistent canonical form regardless of iteration order."""
|
||||||
|
# Create two frozensets with same content
|
||||||
|
fs1 = frozenset([("a", 1), ("b", 2), ("c", 3)])
|
||||||
|
fs2 = frozenset([("c", 3), ("a", 1), ("b", 2)])
|
||||||
|
|
||||||
|
result1 = _canonicalize(fs1)
|
||||||
|
result2 = _canonicalize(fs2)
|
||||||
|
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
def test_nested_frozenset_ordering(self):
|
||||||
|
"""Nested frozensets should also be deterministically ordered."""
|
||||||
|
inner1 = frozenset([1, 2, 3])
|
||||||
|
inner2 = frozenset([3, 2, 1])
|
||||||
|
|
||||||
|
fs1 = frozenset([("key", inner1)])
|
||||||
|
fs2 = frozenset([("key", inner2)])
|
||||||
|
|
||||||
|
result1 = _canonicalize(fs1)
|
||||||
|
result2 = _canonicalize(fs2)
|
||||||
|
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
def test_dict_ordering(self):
|
||||||
|
"""Dicts should be sorted by key."""
|
||||||
|
d1 = {"z": 1, "a": 2, "m": 3}
|
||||||
|
d2 = {"a": 2, "m": 3, "z": 1}
|
||||||
|
|
||||||
|
result1 = _canonicalize(d1)
|
||||||
|
result2 = _canonicalize(d2)
|
||||||
|
|
||||||
|
assert result1 == result2
|
||||||
|
|
||||||
|
def test_tuple_preserved(self):
|
||||||
|
"""Tuples should be marked and preserved."""
|
||||||
|
t = (1, 2, 3)
|
||||||
|
result = _canonicalize(t)
|
||||||
|
|
||||||
|
assert result[0] == "__tuple__"
|
||||||
|
|
||||||
|
def test_list_preserved(self):
|
||||||
|
"""Lists should be recursively canonicalized."""
|
||||||
|
lst = [{"b": 2, "a": 1}, frozenset([3, 2, 1])]
|
||||||
|
result = _canonicalize(lst)
|
||||||
|
|
||||||
|
# First element should be canonicalized dict
|
||||||
|
assert "__dict__" in result[0]
|
||||||
|
# Second element should be canonicalized frozenset
|
||||||
|
assert result[1][0] == "__frozenset__"
|
||||||
|
|
||||||
|
def test_primitives_include_type(self):
|
||||||
|
"""Primitive types should include type name for disambiguation."""
|
||||||
|
assert _canonicalize(42) == ("int", 42)
|
||||||
|
assert _canonicalize(3.14) == ("float", 3.14)
|
||||||
|
assert _canonicalize("hello") == ("str", "hello")
|
||||||
|
assert _canonicalize(True) == ("bool", True)
|
||||||
|
assert _canonicalize(None) == ("NoneType", None)
|
||||||
|
|
||||||
|
def test_int_and_str_distinguished(self):
|
||||||
|
"""int 7 and str '7' must produce different canonical forms."""
|
||||||
|
assert _canonicalize(7) != _canonicalize("7")
|
||||||
|
|
||||||
|
def test_bytes_converted(self):
|
||||||
|
"""Bytes should be converted to hex string."""
|
||||||
|
b = b"\x00\xff"
|
||||||
|
result = _canonicalize(b)
|
||||||
|
|
||||||
|
assert result[0] == "__bytes__"
|
||||||
|
assert result[1] == "00ff"
|
||||||
|
|
||||||
|
def test_set_ordering(self):
|
||||||
|
"""Sets should be sorted like frozensets."""
|
||||||
|
s1 = {3, 1, 2}
|
||||||
|
s2 = {1, 2, 3}
|
||||||
|
|
||||||
|
result1 = _canonicalize(s1)
|
||||||
|
result2 = _canonicalize(s2)
|
||||||
|
|
||||||
|
assert result1 == result2
|
||||||
|
assert result1[0] == "__set__"
|
||||||
|
|
||||||
|
def test_unknown_type_raises(self):
|
||||||
|
"""Unknown types should raise ValueError (fail-closed)."""
|
||||||
|
class CustomObj:
|
||||||
|
pass
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_canonicalize(CustomObj())
|
||||||
|
|
||||||
|
def test_object_with_value_attr_raises(self):
|
||||||
|
"""Objects with .value attribute (Unhashable-like) should raise ValueError."""
|
||||||
|
class FakeUnhashable:
|
||||||
|
def __init__(self):
|
||||||
|
self.value = float('nan')
|
||||||
|
with pytest.raises(ValueError):
|
||||||
|
_canonicalize(FakeUnhashable())
|
||||||
|
|
||||||
|
|
||||||
|
class TestSerializeCacheKey:
|
||||||
|
"""Test _serialize_cache_key for deterministic hashing."""
|
||||||
|
|
||||||
|
def test_same_content_same_hash(self):
|
||||||
|
"""Same content should produce same hash."""
|
||||||
|
key1 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||||
|
key2 = frozenset([("node_1", frozenset([("input", "value")]))])
|
||||||
|
|
||||||
|
hash1 = _serialize_cache_key(key1)
|
||||||
|
hash2 = _serialize_cache_key(key2)
|
||||||
|
|
||||||
|
assert hash1 == hash2
|
||||||
|
|
||||||
|
def test_different_content_different_hash(self):
|
||||||
|
"""Different content should produce different hash."""
|
||||||
|
key1 = frozenset([("node_1", "value_a")])
|
||||||
|
key2 = frozenset([("node_1", "value_b")])
|
||||||
|
|
||||||
|
hash1 = _serialize_cache_key(key1)
|
||||||
|
hash2 = _serialize_cache_key(key2)
|
||||||
|
|
||||||
|
assert hash1 != hash2
|
||||||
|
|
||||||
|
def test_returns_hex_string(self):
|
||||||
|
"""Should return hex string (SHA256 hex digest)."""
|
||||||
|
key = frozenset([("test", 123)])
|
||||||
|
result = _serialize_cache_key(key)
|
||||||
|
|
||||||
|
assert isinstance(result, str)
|
||||||
|
assert len(result) == 64 # SHA256 hex digest is 64 chars
|
||||||
|
|
||||||
|
def test_complex_nested_structure(self):
|
||||||
|
"""Complex nested structures should hash deterministically."""
|
||||||
|
# Note: frozensets can only contain hashable types, so we use
|
||||||
|
# nested frozensets of tuples to represent dict-like structures
|
||||||
|
key = frozenset([
|
||||||
|
("node_1", frozenset([
|
||||||
|
("input_a", ("tuple", "value")),
|
||||||
|
("input_b", frozenset([("nested", "dict")])),
|
||||||
|
])),
|
||||||
|
("node_2", frozenset([
|
||||||
|
("param", 42),
|
||||||
|
])),
|
||||||
|
])
|
||||||
|
|
||||||
|
# Hash twice to verify determinism
|
||||||
|
hash1 = _serialize_cache_key(key)
|
||||||
|
hash2 = _serialize_cache_key(key)
|
||||||
|
|
||||||
|
assert hash1 == hash2
|
||||||
|
|
||||||
|
def test_dict_in_cache_key(self):
|
||||||
|
"""Dicts passed directly to _serialize_cache_key should work."""
|
||||||
|
key = {"node_1": {"input": "value"}, "node_2": 42}
|
||||||
|
|
||||||
|
hash1 = _serialize_cache_key(key)
|
||||||
|
hash2 = _serialize_cache_key(key)
|
||||||
|
|
||||||
|
assert hash1 == hash2
|
||||||
|
assert isinstance(hash1, str)
|
||||||
|
assert len(hash1) == 64
|
||||||
|
|
||||||
|
def test_unknown_type_returns_none(self):
|
||||||
|
"""Non-cacheable types should return None (fail-closed)."""
|
||||||
|
class CustomObj:
|
||||||
|
pass
|
||||||
|
assert _serialize_cache_key(CustomObj()) is None
|
||||||
|
|
||||||
|
|
||||||
|
class TestContainsSelfUnequal:
|
||||||
|
"""Test _contains_self_unequal utility function."""
|
||||||
|
|
||||||
|
def test_nan_float_detected(self):
|
||||||
|
"""NaN floats should be detected (not equal to itself)."""
|
||||||
|
assert _contains_self_unequal(float('nan')) is True
|
||||||
|
|
||||||
|
def test_regular_float_not_detected(self):
|
||||||
|
"""Regular floats are equal to themselves."""
|
||||||
|
assert _contains_self_unequal(3.14) is False
|
||||||
|
assert _contains_self_unequal(0.0) is False
|
||||||
|
assert _contains_self_unequal(-1.5) is False
|
||||||
|
|
||||||
|
def test_infinity_not_detected(self):
|
||||||
|
"""Infinity is equal to itself."""
|
||||||
|
assert _contains_self_unequal(float('inf')) is False
|
||||||
|
assert _contains_self_unequal(float('-inf')) is False
|
||||||
|
|
||||||
|
def test_nan_in_list(self):
|
||||||
|
"""NaN in list should be detected."""
|
||||||
|
assert _contains_self_unequal([1, 2, float('nan'), 4]) is True
|
||||||
|
assert _contains_self_unequal([1, 2, 3, 4]) is False
|
||||||
|
|
||||||
|
def test_nan_in_tuple(self):
|
||||||
|
"""NaN in tuple should be detected."""
|
||||||
|
assert _contains_self_unequal((1, float('nan'))) is True
|
||||||
|
assert _contains_self_unequal((1, 2, 3)) is False
|
||||||
|
|
||||||
|
def test_nan_in_frozenset(self):
|
||||||
|
"""NaN in frozenset should be detected."""
|
||||||
|
assert _contains_self_unequal(frozenset([1, float('nan')])) is True
|
||||||
|
assert _contains_self_unequal(frozenset([1, 2, 3])) is False
|
||||||
|
|
||||||
|
def test_nan_in_dict_value(self):
|
||||||
|
"""NaN in dict value should be detected."""
|
||||||
|
assert _contains_self_unequal({"key": float('nan')}) is True
|
||||||
|
assert _contains_self_unequal({"key": 42}) is False
|
||||||
|
|
||||||
|
def test_nan_in_nested_structure(self):
|
||||||
|
"""NaN in deeply nested structure should be detected."""
|
||||||
|
nested = {"level1": [{"level2": (1, 2, float('nan'))}]}
|
||||||
|
assert _contains_self_unequal(nested) is True
|
||||||
|
|
||||||
|
def test_non_numeric_types(self):
|
||||||
|
"""Non-numeric types should not be self-unequal."""
|
||||||
|
assert _contains_self_unequal("string") is False
|
||||||
|
assert _contains_self_unequal(None) is False
|
||||||
|
assert _contains_self_unequal(True) is False
|
||||||
|
|
||||||
|
def test_object_with_nan_value_attribute(self):
|
||||||
|
"""Objects wrapping NaN in .value should be detected."""
|
||||||
|
class NanWrapper:
|
||||||
|
def __init__(self):
|
||||||
|
self.value = float('nan')
|
||||||
|
assert _contains_self_unequal(NanWrapper()) is True
|
||||||
|
|
||||||
|
def test_custom_self_unequal_object(self):
|
||||||
|
"""Custom objects where not (x == x) should be detected."""
|
||||||
|
class NeverEqual:
|
||||||
|
def __eq__(self, other):
|
||||||
|
return False
|
||||||
|
assert _contains_self_unequal(NeverEqual()) is True
|
||||||
|
|
||||||
|
|
||||||
|
class TestEstimateValueSize:
|
||||||
|
"""Test _estimate_value_size utility function."""
|
||||||
|
|
||||||
|
def test_empty_outputs(self):
|
||||||
|
"""Empty outputs should have zero size."""
|
||||||
|
value = CacheValue(outputs=[])
|
||||||
|
assert _estimate_value_size(value) == 0
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not _torch_available(),
|
||||||
|
reason="PyTorch not available"
|
||||||
|
)
|
||||||
|
def test_tensor_size_estimation(self):
|
||||||
|
"""Tensor size should be estimated correctly."""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
# 1000 float32 elements = 4000 bytes
|
||||||
|
tensor = torch.zeros(1000, dtype=torch.float32)
|
||||||
|
value = CacheValue(outputs=[[tensor]])
|
||||||
|
|
||||||
|
size = _estimate_value_size(value)
|
||||||
|
assert size == 4000
|
||||||
|
|
||||||
|
@pytest.mark.skipif(
|
||||||
|
not _torch_available(),
|
||||||
|
reason="PyTorch not available"
|
||||||
|
)
|
||||||
|
def test_nested_tensor_in_dict(self):
|
||||||
|
"""Tensors nested in dicts should be counted."""
|
||||||
|
import torch
|
||||||
|
|
||||||
|
tensor = torch.zeros(100, dtype=torch.float32) # 400 bytes
|
||||||
|
value = CacheValue(outputs=[[{"samples": tensor}]])
|
||||||
|
|
||||||
|
size = _estimate_value_size(value)
|
||||||
|
assert size == 400
|
||||||
|
|
||||||
|
|
||||||
|
class TestProviderRegistry:
|
||||||
|
"""Test cache provider registration and retrieval."""
|
||||||
|
|
||||||
|
def setup_method(self):
|
||||||
|
"""Clear providers before each test."""
|
||||||
|
_clear_cache_providers()
|
||||||
|
|
||||||
|
def teardown_method(self):
|
||||||
|
"""Clear providers after each test."""
|
||||||
|
_clear_cache_providers()
|
||||||
|
|
||||||
|
def test_register_provider(self):
|
||||||
|
"""Provider should be registered successfully."""
|
||||||
|
provider = MockCacheProvider()
|
||||||
|
register_cache_provider(provider)
|
||||||
|
|
||||||
|
assert _has_cache_providers() is True
|
||||||
|
providers = _get_cache_providers()
|
||||||
|
assert len(providers) == 1
|
||||||
|
assert providers[0] is provider
|
||||||
|
|
||||||
|
def test_unregister_provider(self):
|
||||||
|
"""Provider should be unregistered successfully."""
|
||||||
|
provider = MockCacheProvider()
|
||||||
|
register_cache_provider(provider)
|
||||||
|
unregister_cache_provider(provider)
|
||||||
|
|
||||||
|
assert _has_cache_providers() is False
|
||||||
|
|
||||||
|
def test_multiple_providers(self):
|
||||||
|
"""Multiple providers can be registered."""
|
||||||
|
provider1 = MockCacheProvider()
|
||||||
|
provider2 = MockCacheProvider()
|
||||||
|
|
||||||
|
register_cache_provider(provider1)
|
||||||
|
register_cache_provider(provider2)
|
||||||
|
|
||||||
|
providers = _get_cache_providers()
|
||||||
|
assert len(providers) == 2
|
||||||
|
|
||||||
|
def test_duplicate_registration_ignored(self):
|
||||||
|
"""Registering same provider twice should be ignored."""
|
||||||
|
provider = MockCacheProvider()
|
||||||
|
|
||||||
|
register_cache_provider(provider)
|
||||||
|
register_cache_provider(provider) # Should be ignored
|
||||||
|
|
||||||
|
providers = _get_cache_providers()
|
||||||
|
assert len(providers) == 1
|
||||||
|
|
||||||
|
def test_clear_providers(self):
|
||||||
|
"""_clear_cache_providers should remove all providers."""
|
||||||
|
provider1 = MockCacheProvider()
|
||||||
|
provider2 = MockCacheProvider()
|
||||||
|
|
||||||
|
register_cache_provider(provider1)
|
||||||
|
register_cache_provider(provider2)
|
||||||
|
_clear_cache_providers()
|
||||||
|
|
||||||
|
assert _has_cache_providers() is False
|
||||||
|
assert len(_get_cache_providers()) == 0
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheContext:
|
||||||
|
"""Test CacheContext dataclass."""
|
||||||
|
|
||||||
|
def test_context_creation(self):
|
||||||
|
"""CacheContext should be created with all fields."""
|
||||||
|
context = CacheContext(
|
||||||
|
node_id="node-456",
|
||||||
|
class_type="KSampler",
|
||||||
|
cache_key_hash="a" * 64,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert context.node_id == "node-456"
|
||||||
|
assert context.class_type == "KSampler"
|
||||||
|
assert context.cache_key_hash == "a" * 64
|
||||||
|
|
||||||
|
|
||||||
|
class TestCacheValue:
|
||||||
|
"""Test CacheValue dataclass."""
|
||||||
|
|
||||||
|
def test_value_creation(self):
|
||||||
|
"""CacheValue should be created with outputs."""
|
||||||
|
outputs = [[{"samples": "tensor_data"}]]
|
||||||
|
value = CacheValue(outputs=outputs)
|
||||||
|
|
||||||
|
assert value.outputs == outputs
|
||||||
|
|
||||||
|
|
||||||
|
class MockCacheProvider(CacheProvider):
|
||||||
|
"""Mock cache provider for testing."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.lookups = []
|
||||||
|
self.stores = []
|
||||||
|
|
||||||
|
async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
|
||||||
|
self.lookups.append(context)
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def on_store(self, context: CacheContext, value: CacheValue) -> None:
|
||||||
|
self.stores.append((context, value))
|
||||||
@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
|
|||||||
},
|
},
|
||||||
# JavaScript/CSS scenarios
|
# JavaScript/CSS scenarios
|
||||||
{
|
{
|
||||||
"name": "js_no_cache",
|
"name": "js_no_store",
|
||||||
"path": "/script.js",
|
"path": "/script.js",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "css_no_cache",
|
"name": "css_no_store",
|
||||||
"path": "/styles.css",
|
"path": "/styles.css",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "index_json_no_cache",
|
"name": "index_json_no_store",
|
||||||
"path": "/api/index.json",
|
"path": "/api/index.json",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
"name": "localized_index_json_no_cache",
|
"name": "localized_index_json_no_store",
|
||||||
"path": "/templates/index.zh.json",
|
"path": "/templates/index.zh.json",
|
||||||
"status": 200,
|
"status": 200,
|
||||||
"expected_cache": "no-cache",
|
"expected_cache": "no-store",
|
||||||
"should_have_header": True,
|
"should_have_header": True,
|
||||||
},
|
},
|
||||||
# Non-matching files
|
# Non-matching files
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user