Merge remote-tracking branch 'origin/master' into pysssss/angle-glsl
Some checks failed
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Python Linting / Run Ruff (push) Waiting to run
Python Linting / Run Pylint (push) Waiting to run
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled

This commit is contained in:
pythongosssss 2026-03-24 11:02:39 -07:00
commit 333ff2e8a0
96 changed files with 4595 additions and 788 deletions

103
.github/scripts/check-ai-co-authors.sh vendored Executable file
View File

@ -0,0 +1,103 @@
#!/usr/bin/env bash
# Checks pull request commits for AI agent Co-authored-by trailers.
# Exits non-zero when any are found and prints fix instructions.
set -euo pipefail

# Both positional arguments are mandatory: `${N:?msg}` aborts the script with
# the usage text when the argument is unset or empty.
base_sha="${1:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"
head_sha="${2:?usage: check-ai-co-authors.sh <base_sha> <head_sha>}"

# Known AI coding-agent trailer patterns (case-insensitive).
# Each entry is an extended-regex fragment matched against Co-authored-by lines.
AGENT_PATTERNS=(
  # Anthropic — Claude Code / Amp
  'noreply@anthropic\.com'
  # Cursor
  'cursoragent@cursor\.com'
  # GitHub Copilot
  'copilot-swe-agent\[bot\]'
  'copilot@github\.com'
  # OpenAI Codex
  'noreply@openai\.com'
  'codex@openai\.com'
  # Aider
  'aider@aider\.chat'
  # Google — Gemini / Jules
  'gemini@google\.com'
  'jules@google\.com'
  # Windsurf / Codeium
  '@codeium\.com'
  # Devin
  'devin-ai-integration\[bot\]'
  'devin@cognition\.ai'
  'devin@cognition-labs\.com'
  # Amazon Q Developer
  'amazon-q-developer'
  '@amazon\.com.*[Qq].[Dd]eveloper'
  # Cline
  'cline-bot'
  'cline@cline\.ai'
  # Continue
  'continue-agent'
  'continue@continue\.dev'
  # Sourcegraph
  'noreply@sourcegraph\.com'
  # Generic catch-alls for common agent name patterns
  'Co-authored-by:.*\b[Cc]laude\b'
  'Co-authored-by:.*\b[Cc]opilot\b'
  'Co-authored-by:.*\b[Cc]ursor\b'
  'Co-authored-by:.*\b[Cc]odex\b'
  'Co-authored-by:.*\b[Gg]emini\b'
  'Co-authored-by:.*\b[Aa]ider\b'
  'Co-authored-by:.*\b[Dd]evin\b'
  'Co-authored-by:.*\b[Ww]indsurf\b'
  'Co-authored-by:.*\b[Cc]line\b'
  'Co-authored-by:.*\b[Aa]mazon Q\b'
  'Co-authored-by:.*\b[Jj]ules\b'
  'Co-authored-by:.*\bOpenCode\b'
)

# Build a single alternation regex from all patterns.
# "${array[*]}" joins the elements with the first character of IFS; the IFS
# override lives only inside the command-substitution subshell.
regex="$(IFS='|'; printf '%s' "${AGENT_PATTERNS[*]}")"
# Collect Co-authored-by lines from every commit in the PR range.
#
# The commit list is resolved into a variable first. A failure inside a
# `< <(git rev-list ...)` process substitution is invisible to `set -e`, so an
# unknown or unfetched SHA would yield an empty range and a false
# "nothing found" result. A failing assignment aborts the script instead.
commits="$(git rev-list "${base_sha}..${head_sha}")"

violations=""
while IFS= read -r sha; do
  # An empty range still feeds one blank line through the here-string.
  [[ -n "$sha" ]] || continue

  message="$(git log -1 --format='%B' "$sha")"
  # printf instead of echo: the text is arbitrary commit content and must not
  # be interpreted as echo options. `|| true` keeps a no-match grep (exit 1)
  # from aborting the script under `set -e` / pipefail.
  matched_lines="$(printf '%s\n' "$message" | grep -iE "^Co-authored-by:" || true)"
  [[ -n "$matched_lines" ]] || continue

  while IFS= read -r line; do
    if printf '%s\n' "$line" | grep -iqE "$regex"; then
      short="$(git log -1 --format='%h' "$sha")"
      violations="${violations} ${short}: ${line}"$'\n'
    fi
  done <<< "$matched_lines"
done <<< "$commits"
# Nothing matched: report success and stop.
if [[ -z "$violations" ]]; then
  echo "No AI agent Co-authored-by trailers found."
  exit 0
fi

# At least one offending trailer: print the GitHub Actions error annotation,
# the offending commits, and the fix instructions, then fail the job.
# ${violations} already ends with a newline, which produces the blank line
# that follows the list.
cat <<EOF
::error::AI agent Co-authored-by trailers detected in PR commits.

The following commits contain Co-authored-by trailers from AI coding agents:

${violations}
These trailers should be removed before merging.

To fix, rewrite the commit messages with:
 git rebase -i ${base_sha}

and remove the Co-authored-by lines, then force-push your branch.

If you believe this is a false positive, please open an issue.
EOF
exit 1

View File

@ -0,0 +1,19 @@
# Fails a pull request when any commit in it carries a Co-authored-by trailer
# that matches a known AI coding agent (see .github/scripts/check-ai-co-authors.sh).
name: Check AI Co-Authors

on:
  pull_request:
    # '**' matches every base branch, including names that contain '/'
    # (a bare '*' does not match the '/' character, so PRs targeting e.g.
    # 'release/1.0' would skip this check).
    branches: ['**']

jobs:
  check-ai-co-authors:
    name: Check for AI agent co-author trailers
    runs-on: ubuntu-latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          # Full history, so the base..head commit range can be walked.
          fetch-depth: 0
      - name: Check commits for AI co-author trailers
        run: bash .github/scripts/check-ai-co-authors.sh "${{ github.event.pull_request.base.sha }}" "${{ github.event.pull_request.head.sha }}"

View File

@ -38,6 +38,8 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
## Get Started ## Get Started
### Local
#### [Desktop Application](https://www.comfy.org/download) #### [Desktop Application](https://www.comfy.org/download)
- The easiest way to get started. - The easiest way to get started.
- Available on Windows & macOS. - Available on Windows & macOS.
@ -49,8 +51,13 @@ ComfyUI lets you design and execute advanced stable diffusion pipelines using a
#### [Manual Install](#manual-install-windows-linux) #### [Manual Install](#manual-install-windows-linux)
Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend). Supports all operating systems and GPU types (NVIDIA, AMD, Intel, Apple Silicon, Ascend).
## [Examples](https://comfyanonymous.github.io/ComfyUI_examples/) ### Cloud
See what ComfyUI can do with the [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
#### [Comfy Cloud](https://www.comfy.org/cloud)
- Our official paid cloud version for those who can't afford local hardware.
## Examples
See what ComfyUI can do with the [newer template workflows](https://comfy.org/workflows) or old [example workflows](https://comfyanonymous.github.io/ComfyUI_examples/).
## Features ## Features
- Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything. - Nodes/graph/flowchart interface to experiment and create complex Stable Diffusion workflows without needing to code anything.

View File

@ -8,7 +8,7 @@ from alembic import context
config = context.config config = context.config
from app.database.models import Base from app.database.models import Base, NAMING_CONVENTION
target_metadata = Base.metadata target_metadata = Base.metadata
# other values from the config, defined by the needs of env.py, # other values from the config, defined by the needs of env.py,
@ -51,7 +51,10 @@ def run_migrations_online() -> None:
with connectable.connect() as connection: with connectable.connect() as connection:
context.configure( context.configure(
connection=connection, target_metadata=target_metadata connection=connection,
target_metadata=target_metadata,
render_as_batch=True,
naming_convention=NAMING_CONVENTION,
) )
with context.begin_transaction(): with context.begin_transaction():

View File

@ -0,0 +1,98 @@
"""
Add system_metadata and job_id columns to asset_references.
Change preview_id FK from assets.id to asset_references.id.
Revision ID: 0003_add_metadata_job_id
Revises: 0002_merge_to_asset_references
Create Date: 2026-03-09
"""
from alembic import op
import sqlalchemy as sa
from app.database.models import NAMING_CONVENTION
revision = "0003_add_metadata_job_id"
down_revision = "0002_merge_to_asset_references"
branch_labels = None
depends_on = None
def upgrade() -> None:
    """Apply revision 0003_add_metadata_job_id.

    Steps, in order:
      1. Add nullable ``system_metadata`` (JSON) and ``job_id`` (String(36))
         columns to ``asset_references``.
      2. Null out every existing ``preview_id``, then swap its foreign key from
         ``assets.id`` to a self-reference on ``asset_references.id`` and index
         the column.
      3. Delete ``asset_reference_meta`` rows whose four value columns are all
         NULL, then add a CHECK constraint requiring at least one of them.
    """
    # 1. New nullable columns.
    added_columns = (
        sa.Column("system_metadata", sa.JSON(), nullable=True),
        sa.Column("job_id", sa.String(length=36), nullable=True),
    )
    with op.batch_alter_table("asset_references") as refs_batch:
        for column in added_columns:
            refs_batch.add_column(column)

    # 2. Change preview_id FK from assets.id to asset_references.id (self-ref).
    # Existing values are asset-content IDs that won't match reference IDs,
    # so null them out first.
    clear_previews_sql = "UPDATE asset_references SET preview_id = NULL WHERE preview_id IS NOT NULL"
    op.execute(clear_previews_sql)

    with op.batch_alter_table("asset_references", naming_convention=NAMING_CONVENTION) as refs_batch:
        refs_batch.drop_constraint("fk_asset_references_preview_id_assets", type_="foreignkey")
        refs_batch.create_foreign_key(
            "fk_asset_references_preview_id_asset_references",
            "asset_references",
            ["preview_id"],
            ["id"],
            ondelete="SET NULL",
        )
        refs_batch.create_index("ix_asset_references_preview_id", ["preview_id"])

    # 3. Purge any all-null meta rows before adding the constraint, otherwise
    # those rows would violate it.
    purge_empty_meta_sql = (
        "DELETE FROM asset_reference_meta"
        " WHERE val_str IS NULL AND val_num IS NULL AND val_bool IS NULL AND val_json IS NULL"
    )
    op.execute(purge_empty_meta_sql)

    has_value_condition = (
        "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL"
    )
    with op.batch_alter_table("asset_reference_meta") as meta_batch:
        meta_batch.create_check_constraint("ck_asset_reference_meta_has_value", has_value_condition)
def downgrade() -> None:
    """Revert revision 0003_add_metadata_job_id.

    Undoes the upgrade in reverse order: drops the ``asset_reference_meta``
    CHECK constraint, restores the ``preview_id`` foreign key to ``assets.id``
    (dropping the index and the self-referential key), and finally removes the
    ``job_id`` and ``system_metadata`` columns. Rows and ``preview_id`` values
    removed by the upgrade are not restored.
    """
    # SQLite doesn't reflect CHECK constraints, so we must declare it
    # explicitly via table_args for the batch recreate to find it.
    # Use the fully-rendered constraint name to avoid the naming convention
    # doubling the prefix.
    check_name = "ck_asset_reference_meta_has_value"
    declared_check = sa.CheckConstraint(
        "val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
        name=check_name,
    )
    with op.batch_alter_table("asset_reference_meta", table_args=[declared_check]) as meta_batch:
        meta_batch.drop_constraint(check_name, type_="check")

    # Point preview_id back at assets.id.
    with op.batch_alter_table("asset_references", naming_convention=NAMING_CONVENTION) as refs_batch:
        refs_batch.drop_index("ix_asset_references_preview_id")
        refs_batch.drop_constraint("fk_asset_references_preview_id_asset_references", type_="foreignkey")
        refs_batch.create_foreign_key(
            "fk_asset_references_preview_id_assets",
            "assets",
            ["preview_id"],
            ["id"],
            ondelete="SET NULL",
        )

    # Remove the columns added by the upgrade.
    with op.batch_alter_table("asset_references") as refs_batch:
        for column_name in ("job_id", "system_metadata"):
            refs_batch.drop_column(column_name)

View File

@ -13,6 +13,7 @@ from pydantic import ValidationError
import folder_paths import folder_paths
from app import user_manager from app import user_manager
from app.assets.api import schemas_in, schemas_out from app.assets.api import schemas_in, schemas_out
from app.assets.services import schemas
from app.assets.api.schemas_in import ( from app.assets.api.schemas_in import (
AssetValidationError, AssetValidationError,
UploadError, UploadError,
@ -38,6 +39,7 @@ from app.assets.services import (
update_asset_metadata, update_asset_metadata,
upload_from_temp_path, upload_from_temp_path,
) )
from app.assets.services.tagging import list_tag_histogram
ROUTES = web.RouteTableDef() ROUTES = web.RouteTableDef()
USER_MANAGER: user_manager.UserManager | None = None USER_MANAGER: user_manager.UserManager | None = None
@ -122,6 +124,61 @@ def _validate_sort_field(requested: str | None) -> str:
return "created_at" return "created_at"
def _build_preview_url_from_view(tags: list[str], user_metadata: dict[str, Any] | None) -> str | None:
"""Build a /api/view preview URL from asset tags and user_metadata filename."""
if not user_metadata:
return None
filename = user_metadata.get("filename")
if not filename:
return None
if "input" in tags:
view_type = "input"
elif "output" in tags:
view_type = "output"
else:
return None
subfolder = ""
if "/" in filename:
subfolder, filename = filename.rsplit("/", 1)
encoded_filename = urllib.parse.quote(filename, safe="")
url = f"/api/view?type={view_type}&filename={encoded_filename}"
if subfolder:
url += f"&subfolder={urllib.parse.quote(subfolder, safe='')}"
return url
def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResult) -> schemas_out.Asset:
    """Build an Asset response from a service result.

    The preview URL is derived from the referenced preview (looked up through
    ``get_asset_detail``) when ``result.ref.preview_id`` is set, and from the
    result's own tags and user metadata otherwise. A preview id whose lookup
    returns nothing yields no preview URL.
    """
    ref = result.ref

    preview_url = None
    if not ref.preview_id:
        preview_url = _build_preview_url_from_view(result.tags, ref.user_metadata)
    else:
        preview_detail = get_asset_detail(ref.preview_id)
        if preview_detail:
            preview_url = _build_preview_url_from_view(
                preview_detail.tags, preview_detail.ref.user_metadata
            )

    # The content blob may be absent; its fields then stay None.
    blob = result.asset
    asset_hash = blob.hash if blob else None
    size = int(blob.size_bytes) if blob else None
    mime_type = blob.mime_type if blob else None

    return schemas_out.Asset(
        id=ref.id,
        name=ref.name,
        asset_hash=asset_hash,
        size=size,
        mime_type=mime_type,
        tags=result.tags,
        preview_url=preview_url,
        preview_id=ref.preview_id,
        user_metadata=ref.user_metadata or {},
        metadata=ref.system_metadata,
        job_id=ref.job_id,
        prompt_id=ref.job_id,  # deprecated: mirrors job_id for cloud compat
        created_at=ref.created_at,
        updated_at=ref.updated_at,
        last_access_time=ref.last_access_time,
    )
@ROUTES.head("/api/assets/hash/{hash}") @ROUTES.head("/api/assets/hash/{hash}")
@_require_assets_feature_enabled @_require_assets_feature_enabled
async def head_asset_by_hash(request: web.Request) -> web.Response: async def head_asset_by_hash(request: web.Request) -> web.Response:
@ -164,20 +221,7 @@ async def list_assets_route(request: web.Request) -> web.Response:
order=order, order=order,
) )
summaries = [ summaries = [_build_asset_response(item) for item in result.items]
schemas_out.AssetSummary(
id=item.ref.id,
name=item.ref.name,
asset_hash=item.asset.hash if item.asset else None,
size=int(item.asset.size_bytes) if item.asset else None,
mime_type=item.asset.mime_type if item.asset else None,
tags=item.tags,
created_at=item.ref.created_at,
updated_at=item.ref.updated_at,
last_access_time=item.ref.last_access_time,
)
for item in result.items
]
payload = schemas_out.AssetsList( payload = schemas_out.AssetsList(
assets=summaries, assets=summaries,
@ -207,18 +251,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
{"id": reference_id}, {"id": reference_id},
) )
payload = schemas_out.AssetDetail( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
size=int(result.asset.size_bytes) if result.asset else None,
mime_type=result.asset.mime_type if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
)
except ValueError as e: except ValueError as e:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", str(e), {"id": reference_id} 404, "ASSET_NOT_FOUND", str(e), {"id": reference_id}
@ -230,7 +263,7 @@ async def get_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request), USER_MANAGER.get_request_user_id(request),
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content") @ROUTES.get(f"/api/assets/{{id:{UUID_RE}}}/content")
@ -312,32 +345,31 @@ async def create_asset_from_hash_route(request: web.Request) -> web.Response:
400, "INVALID_JSON", "Request body must be valid JSON." 400, "INVALID_JSON", "Request body must be valid JSON."
) )
# Derive name from hash if not provided
name = body.name
if name is None:
name = body.hash.split(":", 1)[1] if ":" in body.hash else body.hash
result = create_from_hash( result = create_from_hash(
hash_str=body.hash, hash_str=body.hash,
name=body.name, name=name,
tags=body.tags, tags=body.tags,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
mime_type=body.mime_type,
preview_id=body.preview_id,
) )
if result is None: if result is None:
return _build_error_response( return _build_error_response(
404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist" 404, "ASSET_NOT_FOUND", f"Asset content {body.hash} does not exist"
) )
asset = _build_asset_response(result)
payload_out = schemas_out.AssetCreated( payload_out = schemas_out.AssetCreated(
id=result.ref.id, **asset.model_dump(),
name=result.ref.name,
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
return web.json_response(payload_out.model_dump(mode="json"), status=201) return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=201)
@ROUTES.post("/api/assets") @ROUTES.post("/api/assets")
@ -358,6 +390,8 @@ async def upload_asset(request: web.Request) -> web.Response:
"name": parsed.provided_name, "name": parsed.provided_name,
"user_metadata": parsed.user_metadata_raw, "user_metadata": parsed.user_metadata_raw,
"hash": parsed.provided_hash, "hash": parsed.provided_hash,
"mime_type": parsed.provided_mime_type,
"preview_id": parsed.provided_preview_id,
} }
) )
except ValidationError as ve: except ValidationError as ve:
@ -386,6 +420,8 @@ async def upload_asset(request: web.Request) -> web.Response:
tags=spec.tags, tags=spec.tags,
user_metadata=spec.user_metadata or {}, user_metadata=spec.user_metadata or {},
owner_id=owner_id, owner_id=owner_id,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
) )
if result is None: if result is None:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -410,6 +446,8 @@ async def upload_asset(request: web.Request) -> web.Response:
client_filename=parsed.file_client_name, client_filename=parsed.file_client_name,
owner_id=owner_id, owner_id=owner_id,
expected_hash=spec.hash, expected_hash=spec.hash,
mime_type=spec.mime_type,
preview_id=spec.preview_id,
) )
except AssetValidationError as e: except AssetValidationError as e:
delete_temp_file_if_exists(parsed.tmp_path) delete_temp_file_if_exists(parsed.tmp_path)
@ -428,21 +466,13 @@ async def upload_asset(request: web.Request) -> web.Response:
logging.exception("upload_asset failed for owner_id=%s", owner_id) logging.exception("upload_asset failed for owner_id=%s", owner_id)
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
payload = schemas_out.AssetCreated( asset = _build_asset_response(result)
id=result.ref.id, payload_out = schemas_out.AssetCreated(
name=result.ref.name, **asset.model_dump(),
asset_hash=result.asset.hash,
size=int(result.asset.size_bytes),
mime_type=result.asset.mime_type,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
preview_id=result.ref.preview_id,
created_at=result.ref.created_at,
last_access_time=result.ref.last_access_time,
created_new=result.created_new, created_new=result.created_new,
) )
status = 201 if result.created_new else 200 status = 201 if result.created_new else 200
return web.json_response(payload.model_dump(mode="json"), status=status) return web.json_response(payload_out.model_dump(mode="json", exclude_none=True), status=status)
@ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.put(f"/api/assets/{{id:{UUID_RE}}}")
@ -464,15 +494,9 @@ async def update_asset_route(request: web.Request) -> web.Response:
name=body.name, name=body.name,
user_metadata=body.user_metadata, user_metadata=body.user_metadata,
owner_id=USER_MANAGER.get_request_user_id(request), owner_id=USER_MANAGER.get_request_user_id(request),
preview_id=body.preview_id,
) )
payload = schemas_out.AssetUpdated( payload = _build_asset_response(result)
id=result.ref.id,
name=result.ref.name,
asset_hash=result.asset.hash if result.asset else None,
tags=result.tags,
user_metadata=result.ref.user_metadata or {},
updated_at=result.ref.updated_at,
)
except PermissionError as pe: except PermissionError as pe:
return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id}) return _build_error_response(403, "FORBIDDEN", str(pe), {"id": reference_id})
except ValueError as ve: except ValueError as ve:
@ -486,7 +510,7 @@ async def update_asset_route(request: web.Request) -> web.Response:
USER_MANAGER.get_request_user_id(request), USER_MANAGER.get_request_user_id(request),
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}")
@ -555,7 +579,7 @@ async def get_tags(request: web.Request) -> web.Response:
payload = schemas_out.TagsList( payload = schemas_out.TagsList(
tags=tags, total=total, has_more=(query.offset + len(tags)) < total tags=tags, total=total, has_more=(query.offset + len(tags)) < total
) )
return web.json_response(payload.model_dump(mode="json")) return web.json_response(payload.model_dump(mode="json", exclude_none=True))
@ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.post(f"/api/assets/{{id:{UUID_RE}}}/tags")
@ -603,7 +627,7 @@ async def add_asset_tags(request: web.Request) -> web.Response:
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags") @ROUTES.delete(f"/api/assets/{{id:{UUID_RE}}}/tags")
@ -650,7 +674,29 @@ async def delete_asset_tags(request: web.Request) -> web.Response:
) )
return _build_error_response(500, "INTERNAL", "Unexpected server error.") return _build_error_response(500, "INTERNAL", "Unexpected server error.")
return web.json_response(payload.model_dump(mode="json"), status=200) return web.json_response(payload.model_dump(mode="json", exclude_none=True), status=200)
@ROUTES.get("/api/assets/tags/refine")
@_require_assets_feature_enabled
async def get_tags_refine(request: web.Request) -> web.Response:
    """GET request to get tag histogram for filtered assets.

    Query parameters are validated with ``schemas_in.TagsRefineQuery``; a
    validation failure is answered with an ``INVALID_QUERY`` error response.
    On success the histogram is returned as a ``TagHistogram`` JSON payload
    with status 200.
    """
    raw_query = get_query_dict(request)
    try:
        params = schemas_in.TagsRefineQuery.model_validate(raw_query)
    except ValidationError as ve:
        return _build_validation_error_response("INVALID_QUERY", ve)

    requesting_user = USER_MANAGER.get_request_user_id(request)
    histogram = list_tag_histogram(
        owner_id=requesting_user,
        include_tags=params.include_tags,
        exclude_tags=params.exclude_tags,
        name_contains=params.name_contains,
        metadata_filter=params.metadata_filter,
        limit=params.limit,
    )

    body = schemas_out.TagHistogram(tag_counts=histogram).model_dump(mode="json", exclude_none=True)
    return web.json_response(body, status=200)
@ROUTES.post("/api/assets/seed") @ROUTES.post("/api/assets/seed")

View File

@ -45,6 +45,8 @@ class ParsedUpload:
user_metadata_raw: str | None user_metadata_raw: str | None
provided_hash: str | None provided_hash: str | None
provided_hash_exists: bool | None provided_hash_exists: bool | None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
class ListAssetsQuery(BaseModel): class ListAssetsQuery(BaseModel):
@ -98,11 +100,17 @@ class ListAssetsQuery(BaseModel):
class UpdateAssetBody(BaseModel): class UpdateAssetBody(BaseModel):
name: str | None = None name: str | None = None
user_metadata: dict[str, Any] | None = None user_metadata: dict[str, Any] | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@model_validator(mode="after") @model_validator(mode="after")
def _validate_at_least_one_field(self): def _validate_at_least_one_field(self):
if self.name is None and self.user_metadata is None: if all(
raise ValueError("Provide at least one of: name, user_metadata.") v is None
for v in (self.name, self.user_metadata, self.preview_id)
):
raise ValueError(
"Provide at least one of: name, user_metadata, preview_id."
)
return self return self
@ -110,9 +118,11 @@ class CreateFromHashBody(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
hash: str hash: str
name: str name: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
mime_type: str | None = None
preview_id: str | None = None # references an asset_reference id, not an asset id
@field_validator("hash") @field_validator("hash")
@classmethod @classmethod
@ -138,6 +148,44 @@ class CreateFromHashBody(BaseModel):
return [] return []
class TagsRefineQuery(BaseModel):
    """Query parameters accepted by the tag-refine (tag histogram) endpoint."""

    include_tags: list[str] = Field(default_factory=list)
    exclude_tags: list[str] = Field(default_factory=list)
    name_contains: str | None = None
    metadata_filter: dict[str, Any] | None = None
    limit: conint(ge=1, le=1000) = 100

    @field_validator("include_tags", "exclude_tags", mode="before")
    @classmethod
    def _split_csv_tags(cls, v):
        """Normalise tag input into a flat list of stripped, non-empty tags.

        Accepts ``None`` (becomes ``[]``), a comma-separated string, or a list
        whose string items may themselves be comma-separated; non-string list
        items are dropped. Any other value is passed through unchanged.
        """

        def _pieces(text: str) -> list[str]:
            return [part.strip() for part in text.split(",") if part.strip()]

        if v is None:
            return []
        if isinstance(v, str):
            return _pieces(v)
        if isinstance(v, list):
            return [tag for entry in v if isinstance(entry, str) for tag in _pieces(entry)]
        return v

    @field_validator("metadata_filter", mode="before")
    @classmethod
    def _parse_metadata_json(cls, v):
        """Decode ``metadata_filter`` from a JSON string into a dict.

        ``None`` and dict values are returned as-is. A non-blank string must
        decode to a JSON object, otherwise ``ValueError`` is raised. Blank
        strings and values of any other type become ``None``.
        """
        if v is None or isinstance(v, dict):
            return v
        if not isinstance(v, str) or not v.strip():
            return None
        try:
            decoded = json.loads(v)
        except Exception as e:
            raise ValueError(f"metadata_filter must be JSON: {e}") from e
        if isinstance(decoded, dict):
            return decoded
        raise ValueError("metadata_filter must be a JSON object")
class TagsListQuery(BaseModel): class TagsListQuery(BaseModel):
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
@ -186,21 +234,25 @@ class TagsRemove(TagsAdd):
class UploadAssetSpec(BaseModel): class UploadAssetSpec(BaseModel):
"""Upload Asset operation. """Upload Asset operation.
- tags: ordered; first is root ('models'|'input'|'output'); - tags: optional list; if provided, first is root ('models'|'input'|'output');
if root == 'models', second must be a valid category if root == 'models', second must be a valid category
- name: display name - name: display name
- user_metadata: arbitrary JSON object (optional) - user_metadata: arbitrary JSON object (optional)
- hash: optional canonical 'blake3:<hex>' for validation / fast-path - hash: optional canonical 'blake3:<hex>' for validation / fast-path
- mime_type: optional MIME type override
- preview_id: optional asset_reference ID for preview
Files are stored using the content hash as filename stem. Files are stored using the content hash as filename stem.
""" """
model_config = ConfigDict(extra="ignore", str_strip_whitespace=True) model_config = ConfigDict(extra="ignore", str_strip_whitespace=True)
tags: list[str] = Field(..., min_length=1) tags: list[str] = Field(default_factory=list)
name: str | None = Field(default=None, max_length=512, description="Display Name") name: str | None = Field(default=None, max_length=512, description="Display Name")
user_metadata: dict[str, Any] = Field(default_factory=dict) user_metadata: dict[str, Any] = Field(default_factory=dict)
hash: str | None = Field(default=None) hash: str | None = Field(default=None)
mime_type: str | None = Field(default=None)
preview_id: str | None = Field(default=None) # references an asset_reference id
@field_validator("hash", mode="before") @field_validator("hash", mode="before")
@classmethod @classmethod
@ -279,7 +331,7 @@ class UploadAssetSpec(BaseModel):
@model_validator(mode="after") @model_validator(mode="after")
def _validate_order(self): def _validate_order(self):
if not self.tags: if not self.tags:
raise ValueError("tags must be provided and non-empty") raise ValueError("at least one tag is required for uploads")
root = self.tags[0] root = self.tags[0]
if root not in {"models", "input", "output"}: if root not in {"models", "input", "output"}:
raise ValueError("first tag must be one of: models, input, output") raise ValueError("first tag must be one of: models, input, output")

View File

@ -4,7 +4,10 @@ from typing import Any
from pydantic import BaseModel, ConfigDict, Field, field_serializer from pydantic import BaseModel, ConfigDict, Field, field_serializer
class AssetSummary(BaseModel): class Asset(BaseModel):
"""API view of an asset. Maps to DB ``AssetReference`` joined with its ``Asset`` blob;
``id`` here is the AssetReference id, not the content-addressed Asset id."""
id: str id: str
name: str name: str
asset_hash: str | None = None asset_hash: str | None = None
@ -12,8 +15,14 @@ class AssetSummary(BaseModel):
mime_type: str | None = None mime_type: str | None = None
tags: list[str] = Field(default_factory=list) tags: list[str] = Field(default_factory=list)
preview_url: str | None = None preview_url: str | None = None
created_at: datetime | None = None preview_id: str | None = None # references an asset_reference id, not an asset id
updated_at: datetime | None = None user_metadata: dict[str, Any] = Field(default_factory=dict)
is_immutable: bool = False
metadata: dict[str, Any] | None = None
job_id: str | None = None
prompt_id: str | None = None # deprecated: use job_id
created_at: datetime
updated_at: datetime
last_access_time: datetime | None = None last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True) model_config = ConfigDict(from_attributes=True)
@ -23,50 +32,16 @@ class AssetSummary(BaseModel):
return v.isoformat() if v else None return v.isoformat() if v else None
class AssetCreated(Asset):
    """``Asset`` payload extended with a ``created_new`` flag."""

    # Required flag with no default.
    # NOTE(review): presumably True when the request created a new record
    # rather than reusing an existing one — confirm against the service layer.
    created_new: bool
class AssetsList(BaseModel): class AssetsList(BaseModel):
assets: list[AssetSummary] assets: list[Asset]
total: int total: int
has_more: bool has_more: bool
class AssetUpdated(BaseModel):
id: str
name: str
asset_hash: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
updated_at: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("updated_at")
def _serialize_updated_at(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetDetail(BaseModel):
id: str
name: str
asset_hash: str | None = None
size: int | None = None
mime_type: str | None = None
tags: list[str] = Field(default_factory=list)
user_metadata: dict[str, Any] = Field(default_factory=dict)
preview_id: str | None = None
created_at: datetime | None = None
last_access_time: datetime | None = None
model_config = ConfigDict(from_attributes=True)
@field_serializer("created_at", "last_access_time")
def _serialize_datetime(self, v: datetime | None, _info):
return v.isoformat() if v else None
class AssetCreated(AssetDetail):
created_new: bool
class TagUsage(BaseModel): class TagUsage(BaseModel):
name: str name: str
count: int count: int
@ -91,3 +66,7 @@ class TagsRemove(BaseModel):
removed: list[str] = Field(default_factory=list) removed: list[str] = Field(default_factory=list)
not_present: list[str] = Field(default_factory=list) not_present: list[str] = Field(default_factory=list)
total_tags: list[str] = Field(default_factory=list) total_tags: list[str] = Field(default_factory=list)
class TagHistogram(BaseModel):
    """Response payload holding a histogram of tags."""

    # Mapping of string key to integer count.
    # NOTE(review): keys are presumably tag names — confirm against list_tag_histogram.
    tag_counts: dict[str, int]

View File

@ -52,6 +52,8 @@ async def parse_multipart_upload(
user_metadata_raw: str | None = None user_metadata_raw: str | None = None
provided_hash: str | None = None provided_hash: str | None = None
provided_hash_exists: bool | None = None provided_hash_exists: bool | None = None
provided_mime_type: str | None = None
provided_preview_id: str | None = None
file_written = 0 file_written = 0
tmp_path: str | None = None tmp_path: str | None = None
@ -128,6 +130,16 @@ async def parse_multipart_upload(
provided_name = (await field.text()) or None provided_name = (await field.text()) or None
elif fname == "user_metadata": elif fname == "user_metadata":
user_metadata_raw = (await field.text()) or None user_metadata_raw = (await field.text()) or None
elif fname == "id":
raise UploadError(
400,
"UNSUPPORTED_FIELD",
"Client-provided 'id' is not supported. Asset IDs are assigned by the server.",
)
elif fname == "mime_type":
provided_mime_type = ((await field.text()) or "").strip() or None
elif fname == "preview_id":
provided_preview_id = ((await field.text()) or "").strip() or None
if not file_present and not (provided_hash and provided_hash_exists): if not file_present and not (provided_hash and provided_hash_exists):
raise UploadError( raise UploadError(
@ -152,6 +164,8 @@ async def parse_multipart_upload(
user_metadata_raw=user_metadata_raw, user_metadata_raw=user_metadata_raw,
provided_hash=provided_hash, provided_hash=provided_hash,
provided_hash_exists=provided_hash_exists, provided_hash_exists=provided_hash_exists,
provided_mime_type=provided_mime_type,
provided_preview_id=provided_preview_id,
) )

View File

@ -45,13 +45,7 @@ class Asset(Base):
passive_deletes=True, passive_deletes=True,
) )
preview_of: Mapped[list[AssetReference]] = relationship( # preview_id on AssetReference is a self-referential FK to asset_references.id
"AssetReference",
back_populates="preview_asset",
primaryjoin=lambda: Asset.id == foreign(AssetReference.preview_id),
foreign_keys=lambda: [AssetReference.preview_id],
viewonly=True,
)
__table_args__ = ( __table_args__ = (
Index("uq_assets_hash", "hash", unique=True), Index("uq_assets_hash", "hash", unique=True),
@ -91,11 +85,15 @@ class AssetReference(Base):
owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="") owner_id: Mapped[str] = mapped_column(String(128), nullable=False, default="")
name: Mapped[str] = mapped_column(String(512), nullable=False) name: Mapped[str] = mapped_column(String(512), nullable=False)
preview_id: Mapped[str | None] = mapped_column( preview_id: Mapped[str | None] = mapped_column(
String(36), ForeignKey("assets.id", ondelete="SET NULL") String(36), ForeignKey("asset_references.id", ondelete="SET NULL")
) )
user_metadata: Mapped[dict[str, Any] | None] = mapped_column( user_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True) JSON(none_as_null=True)
) )
system_metadata: Mapped[dict[str, Any] | None] = mapped_column(
JSON(none_as_null=True), nullable=True, default=None
)
job_id: Mapped[str | None] = mapped_column(String(36), nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column( created_at: Mapped[datetime] = mapped_column(
DateTime(timezone=False), nullable=False, default=get_utc_now DateTime(timezone=False), nullable=False, default=get_utc_now
) )
@ -115,10 +113,10 @@ class AssetReference(Base):
foreign_keys=[asset_id], foreign_keys=[asset_id],
lazy="selectin", lazy="selectin",
) )
preview_asset: Mapped[Asset | None] = relationship( preview_ref: Mapped[AssetReference | None] = relationship(
"Asset", "AssetReference",
back_populates="preview_of",
foreign_keys=[preview_id], foreign_keys=[preview_id],
remote_side=lambda: [AssetReference.id],
) )
metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship( metadata_entries: Mapped[list[AssetReferenceMeta]] = relationship(
@ -152,6 +150,7 @@ class AssetReference(Base):
Index("ix_asset_references_created_at", "created_at"), Index("ix_asset_references_created_at", "created_at"),
Index("ix_asset_references_last_access_time", "last_access_time"), Index("ix_asset_references_last_access_time", "last_access_time"),
Index("ix_asset_references_deleted_at", "deleted_at"), Index("ix_asset_references_deleted_at", "deleted_at"),
Index("ix_asset_references_preview_id", "preview_id"),
Index("ix_asset_references_owner_name", "owner_id", "name"), Index("ix_asset_references_owner_name", "owner_id", "name"),
CheckConstraint( CheckConstraint(
"(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg" "(mtime_ns IS NULL) OR (mtime_ns >= 0)", name="ck_ar_mtime_nonneg"
@ -192,6 +191,10 @@ class AssetReferenceMeta(Base):
Index("ix_asset_reference_meta_key_val_str", "key", "val_str"), Index("ix_asset_reference_meta_key_val_str", "key", "val_str"),
Index("ix_asset_reference_meta_key_val_num", "key", "val_num"), Index("ix_asset_reference_meta_key_val_num", "key", "val_num"),
Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"), Index("ix_asset_reference_meta_key_val_bool", "key", "val_bool"),
CheckConstraint(
"val_str IS NOT NULL OR val_num IS NOT NULL OR val_bool IS NOT NULL OR val_json IS NOT NULL",
name="has_value",
),
) )

View File

@ -31,16 +31,21 @@ from app.assets.database.queries.asset_reference import (
get_unenriched_references, get_unenriched_references,
get_unreferenced_unhashed_asset_ids, get_unreferenced_unhashed_asset_ids,
insert_reference, insert_reference,
list_all_file_paths_by_asset_id,
list_references_by_asset_id, list_references_by_asset_id,
list_references_page, list_references_page,
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
rebuild_metadata_projection,
reference_exists,
reference_exists_for_asset_id, reference_exists_for_asset_id,
restore_references_by_paths, restore_references_by_paths,
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_system_metadata,
soft_delete_reference_by_id, soft_delete_reference_by_id,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_is_missing_by_asset_id,
update_reference_timestamps, update_reference_timestamps,
update_reference_updated_at, update_reference_updated_at,
upsert_reference, upsert_reference,
@ -54,6 +59,7 @@ from app.assets.database.queries.tags import (
bulk_insert_tags_and_meta, bulk_insert_tags_and_meta,
ensure_tags_exist, ensure_tags_exist,
get_reference_tags, get_reference_tags,
list_tag_counts_for_filtered_assets,
list_tags_with_usage, list_tags_with_usage,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
remove_tags_from_reference, remove_tags_from_reference,
@ -97,20 +103,26 @@ __all__ = [
"get_unenriched_references", "get_unenriched_references",
"get_unreferenced_unhashed_asset_ids", "get_unreferenced_unhashed_asset_ids",
"insert_reference", "insert_reference",
"list_all_file_paths_by_asset_id",
"list_references_by_asset_id", "list_references_by_asset_id",
"list_references_page", "list_references_page",
"list_tag_counts_for_filtered_assets",
"list_tags_with_usage", "list_tags_with_usage",
"mark_references_missing_outside_prefixes", "mark_references_missing_outside_prefixes",
"reassign_asset_references", "reassign_asset_references",
"rebuild_metadata_projection",
"reference_exists",
"reference_exists_for_asset_id", "reference_exists_for_asset_id",
"remove_missing_tag_for_asset_id", "remove_missing_tag_for_asset_id",
"remove_tags_from_reference", "remove_tags_from_reference",
"restore_references_by_paths", "restore_references_by_paths",
"set_reference_metadata", "set_reference_metadata",
"set_reference_preview", "set_reference_preview",
"set_reference_system_metadata",
"soft_delete_reference_by_id", "soft_delete_reference_by_id",
"set_reference_tags", "set_reference_tags",
"update_asset_hash_and_mime", "update_asset_hash_and_mime",
"update_is_missing_by_asset_id",
"update_reference_access_time", "update_reference_access_time",
"update_reference_name", "update_reference_name",
"update_reference_timestamps", "update_reference_timestamps",

View File

@ -69,7 +69,7 @@ def upsert_asset(
if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0: if asset.size_bytes != int(size_bytes) and int(size_bytes) > 0:
asset.size_bytes = int(size_bytes) asset.size_bytes = int(size_bytes)
changed = True changed = True
if mime_type and asset.mime_type != mime_type: if mime_type and not asset.mime_type:
asset.mime_type = mime_type asset.mime_type = mime_type
changed = True changed = True
if changed: if changed:
@ -118,7 +118,7 @@ def update_asset_hash_and_mime(
return False return False
if asset_hash is not None: if asset_hash is not None:
asset.hash = asset_hash asset.hash = asset_hash
if mime_type is not None: if mime_type is not None and not asset.mime_type:
asset.mime_type = mime_type asset.mime_type = mime_type
return True return True

View File

@ -10,7 +10,7 @@ from decimal import Decimal
from typing import NamedTuple, Sequence from typing import NamedTuple, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import delete, exists, select from sqlalchemy import delete, select
from sqlalchemy.dialects import sqlite from sqlalchemy.dialects import sqlite
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session, noload from sqlalchemy.orm import Session, noload
@ -24,12 +24,14 @@ from app.assets.database.models import (
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
MAX_BIND_PARAMS, MAX_BIND_PARAMS,
apply_metadata_filter,
apply_tag_filters,
build_prefix_like_conditions, build_prefix_like_conditions,
build_visible_owner_clause, build_visible_owner_clause,
calculate_rows_per_statement, calculate_rows_per_statement,
iter_chunks, iter_chunks,
) )
from app.assets.helpers import escape_sql_like_string, get_utc_now, normalize_tags from app.assets.helpers import escape_sql_like_string, get_utc_now
def _check_is_scalar(v): def _check_is_scalar(v):
@ -44,15 +46,6 @@ def _check_is_scalar(v):
def _scalar_to_row(key: str, ordinal: int, value) -> dict: def _scalar_to_row(key: str, ordinal: int, value) -> dict:
"""Convert a scalar value to a typed projection row.""" """Convert a scalar value to a typed projection row."""
if value is None:
return {
"key": key,
"ordinal": ordinal,
"val_str": None,
"val_num": None,
"val_bool": None,
"val_json": None,
}
if isinstance(value, bool): if isinstance(value, bool):
return {"key": key, "ordinal": ordinal, "val_bool": bool(value)} return {"key": key, "ordinal": ordinal, "val_bool": bool(value)}
if isinstance(value, (int, float, Decimal)): if isinstance(value, (int, float, Decimal)):
@ -66,96 +59,19 @@ def _scalar_to_row(key: str, ordinal: int, value) -> dict:
def convert_metadata_to_rows(key: str, value) -> list[dict]: def convert_metadata_to_rows(key: str, value) -> list[dict]:
"""Turn a metadata key/value into typed projection rows.""" """Turn a metadata key/value into typed projection rows."""
if value is None: if value is None:
return [_scalar_to_row(key, 0, None)] return []
if _check_is_scalar(value): if _check_is_scalar(value):
return [_scalar_to_row(key, 0, value)] return [_scalar_to_row(key, 0, value)]
if isinstance(value, list): if isinstance(value, list):
if all(_check_is_scalar(x) for x in value): if all(_check_is_scalar(x) for x in value):
return [_scalar_to_row(key, i, x) for i, x in enumerate(value)] return [_scalar_to_row(key, i, x) for i, x in enumerate(value) if x is not None]
return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value)] return [{"key": key, "ordinal": i, "val_json": x} for i, x in enumerate(value) if x is not None]
return [{"key": key, "ordinal": 0, "val_json": value}] return [{"key": key, "ordinal": 0, "val_json": value}]
def _apply_tag_filters(
stmt: sa.sql.Select,
include_tags: Sequence[str] | None = None,
exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
"""include_tags: every tag must be present; exclude_tags: none may be present."""
include_tags = normalize_tags(include_tags)
exclude_tags = normalize_tags(exclude_tags)
if include_tags:
for tag_name in include_tags:
stmt = stmt.where(
exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name == tag_name)
)
)
if exclude_tags:
stmt = stmt.where(
~exists().where(
(AssetReferenceTag.asset_reference_id == AssetReference.id)
& (AssetReferenceTag.tag_name.in_(exclude_tags))
)
)
return stmt
def _apply_metadata_filter(
stmt: sa.sql.Select,
metadata_filter: dict | None = None,
) -> sa.sql.Select:
"""Apply filters using asset_reference_meta projection table."""
if not metadata_filter:
return stmt
def _exists_for_pred(key: str, *preds) -> sa.sql.ClauseElement:
return sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
*preds,
)
def _exists_clause_for_value(key: str, value) -> sa.sql.ClauseElement:
if value is None:
no_row_for_key = sa.not_(
sa.exists().where(
AssetReferenceMeta.asset_reference_id == AssetReference.id,
AssetReferenceMeta.key == key,
)
)
null_row = _exists_for_pred(
key,
AssetReferenceMeta.val_json.is_(None),
AssetReferenceMeta.val_str.is_(None),
AssetReferenceMeta.val_num.is_(None),
AssetReferenceMeta.val_bool.is_(None),
)
return sa.or_(no_row_for_key, null_row)
if isinstance(value, bool):
return _exists_for_pred(key, AssetReferenceMeta.val_bool == bool(value))
if isinstance(value, (int, float, Decimal)):
num = value if isinstance(value, Decimal) else Decimal(str(value))
return _exists_for_pred(key, AssetReferenceMeta.val_num == num)
if isinstance(value, str):
return _exists_for_pred(key, AssetReferenceMeta.val_str == value)
return _exists_for_pred(key, AssetReferenceMeta.val_json == value)
for k, v in metadata_filter.items():
if isinstance(v, list):
ors = [_exists_clause_for_value(k, elem) for elem in v]
if ors:
stmt = stmt.where(sa.or_(*ors))
else:
stmt = stmt.where(_exists_clause_for_value(k, v))
return stmt
def get_reference_by_id( def get_reference_by_id(
@ -212,6 +128,21 @@ def reference_exists_for_asset_id(
return session.execute(q).first() is not None return session.execute(q).first() is not None
def reference_exists(
    session: Session,
    reference_id: str,
) -> bool:
    """Tell whether a non-soft-deleted reference with this ID is present.

    Args:
        session: Session used to run the lookup.
        reference_id: Primary key of the AssetReference to look for.

    Returns:
        True when a row with that ID and a NULL ``deleted_at`` exists,
        False otherwise.
    """
    # Select a constant and cap at one row: only presence matters.
    probe = select(sa.literal(True)).select_from(AssetReference)
    probe = probe.where(AssetReference.id == reference_id)
    probe = probe.where(AssetReference.deleted_at.is_(None))
    hit = session.execute(probe.limit(1)).first()
    return hit is not None
def insert_reference( def insert_reference(
session: Session, session: Session,
asset_id: str, asset_id: str,
@ -336,8 +267,8 @@ def list_references_page(
escaped, esc = escape_sql_like_string(name_contains) escaped, esc = escape_sql_like_string(name_contains)
base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc)) base = base.where(AssetReference.name.ilike(f"%{escaped}%", escape=esc))
base = _apply_tag_filters(base, include_tags, exclude_tags) base = apply_tag_filters(base, include_tags, exclude_tags)
base = _apply_metadata_filter(base, metadata_filter) base = apply_metadata_filter(base, metadata_filter)
sort = (sort or "created_at").lower() sort = (sort or "created_at").lower()
order = (order or "desc").lower() order = (order or "desc").lower()
@ -366,8 +297,8 @@ def list_references_page(
count_stmt = count_stmt.where( count_stmt = count_stmt.where(
AssetReference.name.ilike(f"%{escaped}%", escape=esc) AssetReference.name.ilike(f"%{escaped}%", escape=esc)
) )
count_stmt = _apply_tag_filters(count_stmt, include_tags, exclude_tags) count_stmt = apply_tag_filters(count_stmt, include_tags, exclude_tags)
count_stmt = _apply_metadata_filter(count_stmt, metadata_filter) count_stmt = apply_metadata_filter(count_stmt, metadata_filter)
total = int(session.execute(count_stmt).scalar_one() or 0) total = int(session.execute(count_stmt).scalar_one() or 0)
refs = session.execute(base).unique().scalars().all() refs = session.execute(base).unique().scalars().all()
@ -379,7 +310,7 @@ def list_references_page(
select(AssetReferenceTag.asset_reference_id, Tag.name) select(AssetReferenceTag.asset_reference_id, Tag.name)
.join(Tag, Tag.name == AssetReferenceTag.tag_name) .join(Tag, Tag.name == AssetReferenceTag.tag_name)
.where(AssetReferenceTag.asset_reference_id.in_(id_list)) .where(AssetReferenceTag.asset_reference_id.in_(id_list))
.order_by(AssetReferenceTag.added_at) .order_by(AssetReferenceTag.tag_name.asc())
) )
for ref_id, tag_name in rows.all(): for ref_id, tag_name in rows.all():
tag_map[ref_id].append(tag_name) tag_map[ref_id].append(tag_name)
@ -492,6 +423,42 @@ def update_reference_updated_at(
) )
def rebuild_metadata_projection(session: Session, ref: AssetReference) -> None:
    """Regenerate the AssetReferenceMeta rows belonging to *ref*.

    All existing projection rows for the reference are deleted and flushed
    first. New rows are then derived from
    ``{**system_metadata, **user_metadata}``, so when both dicts define the
    same key the user value wins. Nothing is inserted when the merged dict is
    empty or yields no rows.
    """
    purge = delete(AssetReferenceMeta).where(
        AssetReferenceMeta.asset_reference_id == ref.id
    )
    session.execute(purge)
    session.flush()

    combined = {**(ref.system_metadata or {}), **(ref.user_metadata or {})}
    if not combined:
        return

    # One projection row per typed row emitted for each metadata key.
    projection = [
        AssetReferenceMeta(
            asset_reference_id=ref.id,
            key=row["key"],
            ordinal=int(row["ordinal"]),
            val_str=row.get("val_str"),
            val_num=row.get("val_num"),
            val_bool=row.get("val_bool"),
            val_json=row.get("val_json"),
        )
        for meta_key, meta_value in combined.items()
        for row in convert_metadata_to_rows(meta_key, meta_value)
    ]
    if projection:
        session.add_all(projection)
        session.flush()
def set_reference_metadata( def set_reference_metadata(
session: Session, session: Session,
reference_id: str, reference_id: str,
@ -505,33 +472,24 @@ def set_reference_metadata(
ref.updated_at = get_utc_now() ref.updated_at = get_utc_now()
session.flush() session.flush()
session.execute( rebuild_metadata_projection(session, ref)
delete(AssetReferenceMeta).where(
AssetReferenceMeta.asset_reference_id == reference_id
) def set_reference_system_metadata(
) session: Session,
reference_id: str,
system_metadata: dict | None = None,
) -> None:
"""Set system_metadata on a reference and rebuild the merged projection."""
ref = session.get(AssetReference, reference_id)
if not ref:
raise ValueError(f"AssetReference {reference_id} not found")
ref.system_metadata = system_metadata or {}
ref.updated_at = get_utc_now()
session.flush() session.flush()
if not user_metadata: rebuild_metadata_projection(session, ref)
return
rows: list[AssetReferenceMeta] = []
for k, v in user_metadata.items():
for r in convert_metadata_to_rows(k, v):
rows.append(
AssetReferenceMeta(
asset_reference_id=reference_id,
key=r["key"],
ordinal=int(r["ordinal"]),
val_str=r.get("val_str"),
val_num=r.get("val_num"),
val_bool=r.get("val_bool"),
val_json=r.get("val_json"),
)
)
if rows:
session.add_all(rows)
session.flush()
def delete_reference_by_id( def delete_reference_by_id(
@ -571,19 +529,19 @@ def soft_delete_reference_by_id(
def set_reference_preview( def set_reference_preview(
session: Session, session: Session,
reference_id: str, reference_id: str,
preview_asset_id: str | None = None, preview_reference_id: str | None = None,
) -> None: ) -> None:
"""Set or clear preview_id and bump updated_at. Raises on unknown IDs.""" """Set or clear preview_id and bump updated_at. Raises on unknown IDs."""
ref = session.get(AssetReference, reference_id) ref = session.get(AssetReference, reference_id)
if not ref: if not ref:
raise ValueError(f"AssetReference {reference_id} not found") raise ValueError(f"AssetReference {reference_id} not found")
if preview_asset_id is None: if preview_reference_id is None:
ref.preview_id = None ref.preview_id = None
else: else:
if not session.get(Asset, preview_asset_id): if not session.get(AssetReference, preview_reference_id):
raise ValueError(f"Preview Asset {preview_asset_id} not found") raise ValueError(f"Preview AssetReference {preview_reference_id} not found")
ref.preview_id = preview_asset_id ref.preview_id = preview_reference_id
ref.updated_at = get_utc_now() ref.updated_at = get_utc_now()
session.flush() session.flush()
@ -609,6 +567,8 @@ def list_references_by_asset_id(
session.execute( session.execute(
select(AssetReference) select(AssetReference)
.where(AssetReference.asset_id == asset_id) .where(AssetReference.asset_id == asset_id)
.where(AssetReference.is_missing == False) # noqa: E712
.where(AssetReference.deleted_at.is_(None))
.order_by(AssetReference.id.asc()) .order_by(AssetReference.id.asc())
) )
.scalars() .scalars()
@ -616,6 +576,25 @@ def list_references_by_asset_id(
) )
def list_all_file_paths_by_asset_id(
    session: Session,
    asset_id: str,
) -> list[str]:
    """Collect every non-NULL ``file_path`` recorded for an asset.

    No ``deleted_at`` or ``is_missing`` filter is applied, so paths of
    soft-deleted and missing references are included. Intended for orphan
    cleanup, where all on-disk files of the asset must be removed.
    """
    stmt = (
        select(AssetReference.file_path)
        .where(AssetReference.asset_id == asset_id)
        .where(AssetReference.file_path.isnot(None))
    )
    found = session.execute(stmt).scalars().all()
    return list(found)
def upsert_reference( def upsert_reference(
session: Session, session: Session,
asset_id: str, asset_id: str,
@ -855,6 +834,22 @@ def bulk_update_is_missing(
return total return total
def update_is_missing_by_asset_id(
    session: Session, asset_id: str, value: bool
) -> int:
    """Set ``is_missing`` on the non-soft-deleted references of an asset.

    Rows whose ``deleted_at`` is set are left untouched.

    Returns:
        Number of rows updated, as reported by the statement's ``rowcount``.
    """
    stmt = sa.update(AssetReference)
    stmt = stmt.where(AssetReference.asset_id == asset_id)
    stmt = stmt.where(AssetReference.deleted_at.is_(None))
    outcome = session.execute(stmt.values(is_missing=value))
    return outcome.rowcount
def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int: def delete_references_by_ids(session: Session, reference_ids: list[str]) -> int:
"""Delete references by their IDs. """Delete references by their IDs.

View File

@ -1,12 +1,14 @@
"""Shared utilities for database query modules.""" """Shared utilities for database query modules."""
import os import os
from typing import Iterable from decimal import Decimal
from typing import Iterable, Sequence
import sqlalchemy as sa import sqlalchemy as sa
from sqlalchemy import exists
from app.assets.database.models import AssetReference from app.assets.database.models import AssetReference, AssetReferenceMeta, AssetReferenceTag
from app.assets.helpers import escape_sql_like_string from app.assets.helpers import escape_sql_like_string, normalize_tags
MAX_BIND_PARAMS = 800 MAX_BIND_PARAMS = 800
@ -52,3 +54,74 @@ def build_prefix_like_conditions(
escaped, esc = escape_sql_like_string(base) escaped, esc = escape_sql_like_string(base)
conds.append(AssetReference.file_path.like(escaped + "%", escape=esc)) conds.append(AssetReference.file_path.like(escaped + "%", escape=esc))
return conds return conds
def apply_tag_filters(
    stmt: sa.sql.Select,
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
) -> sa.sql.Select:
    """Restrict *stmt* to references matching the tag constraints.

    Both tag lists go through ``normalize_tags`` first. Every included tag
    must be attached to the reference (one EXISTS per tag); none of the
    excluded tags may be attached (a single NOT EXISTS over all of them).
    """
    required = normalize_tags(include_tags)
    forbidden = normalize_tags(exclude_tags)

    if required:
        for wanted in required:
            has_tag = exists().where(
                (AssetReferenceTag.asset_reference_id == AssetReference.id)
                & (AssetReferenceTag.tag_name == wanted)
            )
            stmt = stmt.where(has_tag)

    if forbidden:
        has_forbidden = exists().where(
            (AssetReferenceTag.asset_reference_id == AssetReference.id)
            & (AssetReferenceTag.tag_name.in_(forbidden))
        )
        stmt = stmt.where(~has_forbidden)

    return stmt
def apply_metadata_filter(
    stmt: sa.sql.Select,
    metadata_filter: dict | None = None,
) -> sa.sql.Select:
    """Narrow *stmt* using the asset_reference_meta projection table.

    Each key of *metadata_filter* adds one WHERE condition:

    * ``None`` matches references with no projection row for the key.
    * bool / number / str values are compared against ``val_bool`` /
      ``val_num`` / ``val_str``; anything else against ``val_json``.
    * a list matches when any of its elements matches (OR); an empty list
      adds no condition.

    An empty or missing filter returns *stmt* unchanged.
    """
    if not metadata_filter:
        return stmt

    def _row_exists(key: str, *extra) -> sa.sql.ClauseElement:
        # EXISTS over the projection rows of this reference for one key.
        return sa.exists().where(
            AssetReferenceMeta.asset_reference_id == AssetReference.id,
            AssetReferenceMeta.key == key,
            *extra,
        )

    def _clause_for(key: str, value) -> sa.sql.ClauseElement:
        if value is None:
            return sa.not_(_row_exists(key))
        # bool is tested before the numeric branch because bool is an int.
        if isinstance(value, bool):
            column, target = AssetReferenceMeta.val_bool, bool(value)
        elif isinstance(value, (int, float, Decimal)):
            column = AssetReferenceMeta.val_num
            target = value if isinstance(value, Decimal) else Decimal(str(value))
        elif isinstance(value, str):
            column, target = AssetReferenceMeta.val_str, value
        else:
            column, target = AssetReferenceMeta.val_json, value
        return _row_exists(key, column == target)

    for meta_key, wanted in metadata_filter.items():
        if not isinstance(wanted, list):
            stmt = stmt.where(_clause_for(meta_key, wanted))
            continue
        alternatives = [_clause_for(meta_key, item) for item in wanted]
        if alternatives:
            stmt = stmt.where(sa.or_(*alternatives))
    return stmt

View File

@ -8,12 +8,15 @@ from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from app.assets.database.models import ( from app.assets.database.models import (
Asset,
AssetReference, AssetReference,
AssetReferenceMeta, AssetReferenceMeta,
AssetReferenceTag, AssetReferenceTag,
Tag, Tag,
) )
from app.assets.database.queries.common import ( from app.assets.database.queries.common import (
apply_metadata_filter,
apply_tag_filters,
build_visible_owner_clause, build_visible_owner_clause,
iter_row_chunks, iter_row_chunks,
) )
@ -72,9 +75,9 @@ def get_reference_tags(session: Session, reference_id: str) -> list[str]:
tag_name tag_name
for (tag_name,) in ( for (tag_name,) in (
session.execute( session.execute(
select(AssetReferenceTag.tag_name).where( select(AssetReferenceTag.tag_name)
AssetReferenceTag.asset_reference_id == reference_id .where(AssetReferenceTag.asset_reference_id == reference_id)
) .order_by(AssetReferenceTag.tag_name.asc())
) )
).all() ).all()
] ]
@ -117,7 +120,7 @@ def set_reference_tags(
) )
session.flush() session.flush()
return SetTagsResult(added=to_add, removed=to_remove, total=desired) return SetTagsResult(added=sorted(to_add), removed=sorted(to_remove), total=sorted(desired))
def add_tags_to_reference( def add_tags_to_reference(
@ -272,6 +275,12 @@ def list_tags_with_usage(
.select_from(AssetReferenceTag) .select_from(AssetReferenceTag)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id)) .where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None)) .where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name) .group_by(AssetReferenceTag.tag_name)
.subquery() .subquery()
@ -308,6 +317,12 @@ def list_tags_with_usage(
select(AssetReferenceTag.tag_name) select(AssetReferenceTag.tag_name)
.join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id) .join(AssetReference, AssetReference.id == AssetReferenceTag.asset_reference_id)
.where(build_visible_owner_clause(owner_id)) .where(build_visible_owner_clause(owner_id))
.where(
sa.or_(
AssetReference.is_missing == False, # noqa: E712
AssetReferenceTag.tag_name == "missing",
)
)
.where(AssetReference.deleted_at.is_(None)) .where(AssetReference.deleted_at.is_(None))
.group_by(AssetReferenceTag.tag_name) .group_by(AssetReferenceTag.tag_name)
) )
@ -320,6 +335,53 @@ def list_tags_with_usage(
return rows_norm, int(total or 0) return rows_norm, int(total or 0)
def list_tag_counts_for_filtered_assets(
    session: Session,
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 100,
) -> dict[str, int]:
    """Count tag usage across the references selected by the filters.

    A reference is selected when it passes the owner-visibility clause, is
    not missing, is not soft-deleted and satisfies the optional name
    substring, tag and metadata filters.

    Returns:
        ``{tag_name: count}`` for at most *limit* tags, ordered by count
        descending and then tag name ascending.
    """
    # Step 1: IDs of the references that survive every filter.
    matching = (
        select(AssetReference.id)
        .join(Asset, Asset.id == AssetReference.asset_id)
        .where(build_visible_owner_clause(owner_id))
        .where(AssetReference.is_missing == False)  # noqa: E712
        .where(AssetReference.deleted_at.is_(None))
    )
    if name_contains:
        pattern, escape_char = escape_sql_like_string(name_contains)
        matching = matching.where(
            AssetReference.name.ilike(f"%{pattern}%", escape=escape_char)
        )
    matching = apply_tag_filters(matching, include_tags, exclude_tags)
    matching = apply_metadata_filter(matching, metadata_filter)
    matching_sq = matching.subquery()

    # Step 2: group the tag links of those references by tag name.
    counts_stmt = select(
        AssetReferenceTag.tag_name,
        func.count(AssetReferenceTag.asset_reference_id).label("cnt"),
    )
    counts_stmt = counts_stmt.where(
        AssetReferenceTag.asset_reference_id.in_(select(matching_sq.c.id))
    )
    counts_stmt = counts_stmt.group_by(AssetReferenceTag.tag_name)
    counts_stmt = counts_stmt.order_by(
        func.count(AssetReferenceTag.asset_reference_id).desc(),
        AssetReferenceTag.tag_name.asc(),
    )
    counts_stmt = counts_stmt.limit(limit)

    histogram: dict[str, int] = {}
    for tag_name, cnt in session.execute(counts_stmt).all():
        histogram[tag_name] = int(cnt)
    return histogram
def bulk_insert_tags_and_meta( def bulk_insert_tags_and_meta(
session: Session, session: Session,
tag_rows: list[dict], tag_rows: list[dict],

View File

@ -18,7 +18,7 @@ from app.assets.database.queries import (
mark_references_missing_outside_prefixes, mark_references_missing_outside_prefixes,
reassign_asset_references, reassign_asset_references,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_system_metadata,
update_asset_hash_and_mime, update_asset_hash_and_mime,
) )
from app.assets.services.bulk_ingest import ( from app.assets.services.bulk_ingest import (
@ -490,8 +490,8 @@ def enrich_asset(
logging.warning("Failed to hash %s: %s", file_path, e) logging.warning("Failed to hash %s: %s", file_path, e)
if extract_metadata and metadata: if extract_metadata and metadata:
user_metadata = metadata.to_user_metadata() system_metadata = metadata.to_user_metadata()
set_reference_metadata(session, reference_id, user_metadata) set_reference_system_metadata(session, reference_id, system_metadata)
if full_hash: if full_hash:
existing = get_asset_by_hash(session, full_hash) existing = get_asset_by_hash(session, full_hash)

View File

@ -16,10 +16,12 @@ from app.assets.database.queries import (
get_reference_by_id, get_reference_by_id,
get_reference_with_owner_check, get_reference_with_owner_check,
list_references_page, list_references_page,
list_all_file_paths_by_asset_id,
list_references_by_asset_id, list_references_by_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_preview, set_reference_preview,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
update_reference_access_time, update_reference_access_time,
update_reference_name, update_reference_name,
update_reference_updated_at, update_reference_updated_at,
@ -67,6 +69,8 @@ def update_asset_metadata(
user_metadata: UserMetadata = None, user_metadata: UserMetadata = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
ref = get_reference_with_owner_check(session, reference_id, owner_id) ref = get_reference_with_owner_check(session, reference_id, owner_id)
@ -103,6 +107,21 @@ def update_asset_metadata(
) )
touched = True touched = True
if mime_type is not None:
updated = update_asset_hash_and_mime(
session, asset_id=ref.asset_id, mime_type=mime_type
)
if updated:
touched = True
if preview_id is not None:
set_reference_preview(
session,
reference_id=reference_id,
preview_reference_id=preview_id,
)
touched = True
if touched and user_metadata is None: if touched and user_metadata is None:
update_reference_updated_at(session, reference_id=reference_id) update_reference_updated_at(session, reference_id=reference_id)
@ -159,11 +178,9 @@ def delete_asset_reference(
session.commit() session.commit()
return True return True
# Orphaned asset - delete it and its files # Orphaned asset - gather ALL file paths (including
refs = list_references_by_asset_id(session, asset_id=asset_id) # soft-deleted / missing refs) so their on-disk files get cleaned up.
file_paths = [ file_paths = list_all_file_paths_by_asset_id(session, asset_id=asset_id)
r.file_path for r in (refs or []) if getattr(r, "file_path", None)
]
# Also include the just-deleted file path # Also include the just-deleted file path
if file_path: if file_path:
file_paths.append(file_path) file_paths.append(file_path)
@ -185,7 +202,7 @@ def delete_asset_reference(
def set_asset_preview( def set_asset_preview(
reference_id: str, reference_id: str,
preview_asset_id: str | None = None, preview_reference_id: str | None = None,
owner_id: str = "", owner_id: str = "",
) -> AssetDetailResult: ) -> AssetDetailResult:
with create_session() as session: with create_session() as session:
@ -194,7 +211,7 @@ def set_asset_preview(
set_reference_preview( set_reference_preview(
session, session,
reference_id=reference_id, reference_id=reference_id,
preview_asset_id=preview_asset_id, preview_reference_id=preview_reference_id,
) )
result = fetch_reference_asset_and_tags( result = fetch_reference_asset_and_tags(
@ -263,6 +280,47 @@ def list_assets_page(
return ListAssetsResult(items=items, total=total) return ListAssetsResult(items=items, total=total)
def resolve_hash_to_path(
    asset_hash: str,
    owner_id: str = "",
) -> DownloadResolutionResult | None:
    """Look up an asset by hash and pick an on-disk file to serve for it.

    A reference is considered only if it is owner-less (``owner_id == ""``)
    or belongs to *owner_id*.

    Returns:
        A DownloadResolutionResult carrying ``abs_path``, ``content_type`` and
        ``download_name``; ``None`` when no asset matches the hash or no live
        path can be selected from the visible references.
    """
    with create_session() as session:
        asset = queries_get_asset_by_hash(session, asset_hash)
        if not asset:
            return None

        all_refs = list_references_by_asset_id(session, asset_id=asset.id)
        visible_refs = [ref for ref in all_refs if ref.owner_id in ("", owner_id)]

        chosen_path = select_best_live_path(visible_refs)
        if not chosen_path:
            return None

        # Prefer the name of the first visible reference stored at the chosen
        # path; otherwise fall back to the file's basename.
        basename = os.path.basename(chosen_path)
        download_name = next(
            (
                ref.name
                for ref in visible_refs
                if ref.file_path == chosen_path and ref.name
            ),
            basename,
        )

        # Stored MIME type wins, then a guess from the name, then a generic type.
        content_type = asset.mime_type
        if not content_type:
            guessed = mimetypes.guess_type(download_name)[0]
            content_type = guessed or "application/octet-stream"

        return DownloadResolutionResult(
            abs_path=chosen_path,
            content_type=content_type,
            download_name=download_name,
        )
def resolve_asset_for_download( def resolve_asset_for_download(
reference_id: str, reference_id: str,
owner_id: str = "", owner_id: str = "",

View File

@ -11,13 +11,14 @@ from app.assets.database.queries import (
add_tags_to_reference, add_tags_to_reference,
fetch_reference_and_asset, fetch_reference_and_asset,
get_asset_by_hash, get_asset_by_hash,
get_existing_asset_ids,
get_reference_by_file_path, get_reference_by_file_path,
get_reference_tags, get_reference_tags,
get_or_create_reference, get_or_create_reference,
reference_exists,
remove_missing_tag_for_asset_id, remove_missing_tag_for_asset_id,
set_reference_metadata, set_reference_metadata,
set_reference_tags, set_reference_tags,
update_asset_hash_and_mime,
upsert_asset, upsert_asset,
upsert_reference, upsert_reference,
validate_tags_exist, validate_tags_exist,
@ -26,6 +27,7 @@ from app.assets.helpers import normalize_tags
from app.assets.services.file_utils import get_size_and_mtime_ns from app.assets.services.file_utils import get_size_and_mtime_ns
from app.assets.services.path_utils import ( from app.assets.services.path_utils import (
compute_relative_filename, compute_relative_filename,
get_name_and_tags_from_asset_path,
resolve_destination_from_tags, resolve_destination_from_tags,
validate_path_within_base, validate_path_within_base,
) )
@ -65,7 +67,7 @@ def _ingest_file_from_path(
with create_session() as session: with create_session() as session:
if preview_id: if preview_id:
if preview_id not in get_existing_asset_ids(session, [preview_id]): if not reference_exists(session, preview_id):
preview_id = None preview_id = None
asset, asset_created, asset_updated = upsert_asset( asset, asset_created, asset_updated = upsert_asset(
@ -135,6 +137,8 @@ def _register_existing_asset(
tags: list[str] | None = None, tags: list[str] | None = None,
tag_origin: str = "manual", tag_origin: str = "manual",
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> RegisterAssetResult: ) -> RegisterAssetResult:
user_metadata = user_metadata or {} user_metadata = user_metadata or {}
@ -143,14 +147,25 @@ def _register_existing_asset(
if not asset: if not asset:
raise ValueError(f"No asset with hash {asset_hash}") raise ValueError(f"No asset with hash {asset_hash}")
if mime_type and not asset.mime_type:
update_asset_hash_and_mime(session, asset_id=asset.id, mime_type=mime_type)
if preview_id:
if not reference_exists(session, preview_id):
preview_id = None
ref, ref_created = get_or_create_reference( ref, ref_created = get_or_create_reference(
session, session,
asset_id=asset.id, asset_id=asset.id,
owner_id=owner_id, owner_id=owner_id,
name=name, name=name,
preview_id=preview_id,
) )
if not ref_created: if not ref_created:
if preview_id and ref.preview_id != preview_id:
ref.preview_id = preview_id
tag_names = get_reference_tags(session, reference_id=ref.id) tag_names = get_reference_tags(session, reference_id=ref.id)
result = RegisterAssetResult( result = RegisterAssetResult(
ref=extract_reference_data(ref), ref=extract_reference_data(ref),
@ -242,6 +257,8 @@ def upload_from_temp_path(
client_filename: str | None = None, client_filename: str | None = None,
owner_id: str = "", owner_id: str = "",
expected_hash: str | None = None, expected_hash: str | None = None,
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult: ) -> UploadResult:
try: try:
digest, _ = hashing.compute_blake3_hash(temp_path) digest, _ = hashing.compute_blake3_hash(temp_path)
@ -270,6 +287,8 @@ def upload_from_temp_path(
tags=tags or [], tags=tags or [],
tag_origin="manual", tag_origin="manual",
owner_id=owner_id, owner_id=owner_id,
mime_type=mime_type,
preview_id=preview_id,
) )
return UploadResult( return UploadResult(
ref=result.ref, ref=result.ref,
@ -291,7 +310,7 @@ def upload_from_temp_path(
dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename)) dest_abs = os.path.abspath(os.path.join(dest_dir, hashed_basename))
validate_path_within_base(dest_abs, base_dir) validate_path_within_base(dest_abs, base_dir)
content_type = ( content_type = mime_type or (
mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0] mimetypes.guess_type(os.path.basename(src_for_ext), strict=False)[0]
or mimetypes.guess_type(hashed_basename, strict=False)[0] or mimetypes.guess_type(hashed_basename, strict=False)[0]
or "application/octet-stream" or "application/octet-stream"
@ -315,7 +334,7 @@ def upload_from_temp_path(
mime_type=content_type, mime_type=content_type,
info_name=_sanitize_filename(name or client_filename, fallback=digest), info_name=_sanitize_filename(name or client_filename, fallback=digest),
owner_id=owner_id, owner_id=owner_id,
preview_id=None, preview_id=preview_id,
user_metadata=user_metadata or {}, user_metadata=user_metadata or {},
tags=tags, tags=tags,
tag_origin="manual", tag_origin="manual",
@ -342,30 +361,99 @@ def upload_from_temp_path(
) )
def register_file_in_place(
    abs_path: str,
    name: str,
    tags: list[str],
    owner_id: str = "",
    mime_type: str | None = None,
) -> UploadResult:
    """Register an already-saved file in the asset database without moving it.

    Tags are derived from the filesystem path (root category + subfolder names),
    merged with any caller-provided tags, matching the behavior of the scanner.
    If the path is not under a known root, only the caller-provided tags are used.

    Args:
        abs_path: Path of the on-disk file to hash and register.
        name: Display name; it is sanitized, with the blake3 digest used as
            the fallback.
        tags: Caller-provided tags, merged after the path-derived tags and
            then normalized.
        owner_id: Owner recorded on the reference (``""`` by default).
        mime_type: Explicit content type. When falsy, the type is guessed from
            ``abs_path`` and defaults to ``application/octet-stream``.

    Returns:
        An UploadResult with the reference data, asset data, the reference's
        tag names, and ``created_new`` set from the ingest result.

    Raises:
        DependencyMissingError: If hashing fails with an ImportError.
        RuntimeError: If hashing fails for any other reason, if ingest yields
            no reference id, or if the reference cannot be fetched afterwards.
    """
    try:
        _, path_tags = get_name_and_tags_from_asset_path(abs_path)
    except ValueError:
        # Treated as "no path-derived tags" (presumably the path is outside
        # every known root - confirm against get_name_and_tags_from_asset_path).
        path_tags = []
    merged_tags = normalize_tags([*path_tags, *tags])

    try:
        digest, _ = hashing.compute_blake3_hash(abs_path)
    except ImportError as e:
        # Chain explicitly so the original import failure stays attached as
        # the cause of the domain error.
        raise DependencyMissingError(str(e)) from e
    except Exception as e:
        raise RuntimeError(f"failed to hash file: {e}") from e
    asset_hash = "blake3:" + digest

    size_bytes, mtime_ns = get_size_and_mtime_ns(abs_path)
    content_type = mime_type or (
        mimetypes.guess_type(abs_path, strict=False)[0]
        or "application/octet-stream"
    )

    ingest_result = _ingest_file_from_path(
        abs_path=abs_path,
        asset_hash=asset_hash,
        size_bytes=size_bytes,
        mtime_ns=mtime_ns,
        mime_type=content_type,
        info_name=_sanitize_filename(name, fallback=digest),
        owner_id=owner_id,
        tags=merged_tags,
        tag_origin="upload",
        require_existing_tags=False,
    )

    reference_id = ingest_result.reference_id
    if not reference_id:
        raise RuntimeError("failed to create asset reference")

    with create_session() as session:
        pair = fetch_reference_and_asset(
            session, reference_id=reference_id, owner_id=owner_id
        )
        if not pair:
            raise RuntimeError("inconsistent DB state after ingest")
        ref, asset = pair
        tag_names = get_reference_tags(session, reference_id=ref.id)

        # Copy the row attributes into plain result objects while the session
        # that loaded them is still open.
        return UploadResult(
            ref=extract_reference_data(ref),
            asset=extract_asset_data(asset),
            tags=tag_names,
            created_new=ingest_result.asset_created,
        )
def create_from_hash( def create_from_hash(
hash_str: str, hash_str: str,
name: str, name: str,
tags: list[str] | None = None, tags: list[str] | None = None,
user_metadata: dict | None = None, user_metadata: dict | None = None,
owner_id: str = "", owner_id: str = "",
mime_type: str | None = None,
preview_id: str | None = None,
) -> UploadResult | None: ) -> UploadResult | None:
canonical = hash_str.strip().lower() canonical = hash_str.strip().lower()
with create_session() as session: try:
asset = get_asset_by_hash(session, asset_hash=canonical) result = _register_existing_asset(
if not asset: asset_hash=canonical,
return None name=_sanitize_filename(
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical
result = _register_existing_asset( ),
asset_hash=canonical, user_metadata=user_metadata or {},
name=_sanitize_filename( tags=tags or [],
name, fallback=canonical.split(":", 1)[1] if ":" in canonical else canonical tag_origin="manual",
), owner_id=owner_id,
user_metadata=user_metadata or {}, mime_type=mime_type,
tags=tags or [], preview_id=preview_id,
tag_origin="manual", )
owner_id=owner_id, except ValueError:
) logging.warning("create_from_hash: no asset found for hash %s", canonical)
return None
return UploadResult( return UploadResult(
ref=result.ref, ref=result.ref,

View File

@ -25,7 +25,9 @@ class ReferenceData:
preview_id: str | None preview_id: str | None
created_at: datetime created_at: datetime
updated_at: datetime updated_at: datetime
last_access_time: datetime | None system_metadata: dict[str, Any] | None = None
job_id: str | None = None
last_access_time: datetime | None = None
@dataclass(frozen=True) @dataclass(frozen=True)
@ -93,6 +95,8 @@ def extract_reference_data(ref: AssetReference) -> ReferenceData:
file_path=ref.file_path, file_path=ref.file_path,
user_metadata=ref.user_metadata, user_metadata=ref.user_metadata,
preview_id=ref.preview_id, preview_id=ref.preview_id,
system_metadata=ref.system_metadata,
job_id=ref.job_id,
created_at=ref.created_at, created_at=ref.created_at,
updated_at=ref.updated_at, updated_at=ref.updated_at,
last_access_time=ref.last_access_time, last_access_time=ref.last_access_time,

View File

@ -1,3 +1,5 @@
from typing import Sequence
from app.assets.database.queries import ( from app.assets.database.queries import (
AddTagsResult, AddTagsResult,
RemoveTagsResult, RemoveTagsResult,
@ -6,6 +8,7 @@ from app.assets.database.queries import (
list_tags_with_usage, list_tags_with_usage,
remove_tags_from_reference, remove_tags_from_reference,
) )
from app.assets.database.queries.tags import list_tag_counts_for_filtered_assets
from app.assets.services.schemas import TagUsage from app.assets.services.schemas import TagUsage
from app.database.db import create_session from app.database.db import create_session
@ -73,3 +76,23 @@ def list_tags(
) )
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
def list_tag_histogram(
    owner_id: str = "",
    include_tags: Sequence[str] | None = None,
    exclude_tags: Sequence[str] | None = None,
    name_contains: str | None = None,
    metadata_filter: dict | None = None,
    limit: int = 100,
) -> dict[str, int]:
    """Return tag counts for the assets matching the given filters.

    Opens a database session and forwards every argument unchanged to
    ``list_tag_counts_for_filtered_assets``; the query's result is returned
    as-is.
    """
    query_filters = {
        "owner_id": owner_id,
        "include_tags": include_tags,
        "exclude_tags": exclude_tags,
        "name_contains": name_contains,
        "metadata_filter": metadata_filter,
        "limit": limit,
    }
    with create_session() as session:
        histogram = list_tag_counts_for_filtered_assets(session, **query_filters)
    return histogram

View File

@ -1,9 +1,18 @@
from typing import Any from typing import Any
from datetime import datetime from datetime import datetime
from sqlalchemy import MetaData
from sqlalchemy.orm import DeclarativeBase from sqlalchemy.orm import DeclarativeBase
NAMING_CONVENTION = {
"ix": "ix_%(table_name)s_%(column_0_N_name)s",
"uq": "uq_%(table_name)s_%(column_0_N_name)s",
"ck": "ck_%(table_name)s_%(constraint_name)s",
"fk": "fk_%(table_name)s_%(column_0_name)s_%(referred_table_name)s",
"pk": "pk_%(table_name)s",
}
class Base(DeclarativeBase): class Base(DeclarativeBase):
pass metadata = MetaData(naming_convention=NAMING_CONVENTION)
def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]: def to_dict(obj: Any, include_none: bool = False) -> dict[str, Any]:
fields = obj.__table__.columns.keys() fields = obj.__table__.columns.keys()

View File

@ -6,6 +6,7 @@ import uuid
import glob import glob
import shutil import shutil
import logging import logging
import tempfile
from aiohttp import web from aiohttp import web
from urllib import parse from urllib import parse
from comfy.cli_args import args from comfy.cli_args import args
@ -377,8 +378,15 @@ class UserManager():
try: try:
body = await request.read() body = await request.read()
with open(path, "wb") as f: dir_name = os.path.dirname(path)
f.write(body) fd, tmp_path = tempfile.mkstemp(dir=dir_name)
try:
with os.fdopen(fd, "wb") as f:
f.write(body)
os.replace(tmp_path, path)
except:
os.unlink(tmp_path)
raise
except OSError as e: except OSError as e:
logging.warning(f"Error saving file '{path}': {e}") logging.warning(f"Error saving file '{path}': {e}")
return web.Response( return web.Response(

View File

@ -83,6 +83,8 @@ fpte_group.add_argument("--fp16-text-enc", action="store_true", help="Store text
fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.") fpte_group.add_argument("--fp32-text-enc", action="store_true", help="Store text encoder weights in fp32.")
fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.") fpte_group.add_argument("--bf16-text-enc", action="store_true", help="Store text encoder weights in bf16.")
parser.add_argument("--fp16-intermediates", action="store_true", help="Experimental: Use fp16 for intermediate tensors between nodes instead of fp32.")
parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.") parser.add_argument("--force-channels-last", action="store_true", help="Force channels last format when inferencing the models.")
parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.") parser.add_argument("--directml", type=int, nargs="?", metavar="DIRECTML_DEVICE", const=-1, help="Use torch-directml.")
@ -147,6 +149,7 @@ parser.add_argument("--reserve-vram", type=float, default=None, help="Set the am
parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.") parser.add_argument("--async-offload", nargs='?', const=2, type=int, default=None, metavar="NUM_STREAMS", help="Use async weight offloading. An optional argument controls the amount of offload streams. Default is 2. Enabled by default on Nvidia.")
parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.") parser.add_argument("--disable-async-offload", action="store_true", help="Disable async weight offloading.")
parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.") parser.add_argument("--disable-dynamic-vram", action="store_true", help="Disable dynamic VRAM and use estimate based model loading.")
parser.add_argument("--enable-dynamic-vram", action="store_true", help="Enable dynamic VRAM on systems where it's not enabled by default.")
parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.") parser.add_argument("--force-non-blocking", action="store_true", help="Force ComfyUI to use non-blocking operations for all applicable tensors. This may improve performance on some non-Nvidia systems but can cause issues with some workflows.")
@ -260,4 +263,6 @@ else:
args.fast = set(args.fast) args.fast = set(args.fast)
def enables_dynamic_vram(): def enables_dynamic_vram():
if args.enable_dynamic_vram:
return True
return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu return not args.disable_dynamic_vram and not args.highvram and not args.gpu_only and not args.novram and not args.cpu

View File

@ -176,8 +176,8 @@ class InputTypeOptions(TypedDict):
"""COMBO type only. Specifies the configuration for a multi-select widget. """COMBO type only. Specifies the configuration for a multi-select widget.
Available after ComfyUI frontend v1.13.4 Available after ComfyUI frontend v1.13.4
https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987""" https://github.com/Comfy-Org/ComfyUI_frontend/pull/2987"""
gradient_stops: NotRequired[list[list[float]]] gradient_stops: NotRequired[list[dict]]
"""Gradient color stops for gradientslider display mode. Each stop is [offset, r, g, b] (``FLOAT``).""" """Gradient color stops for gradientslider display mode. Each stop is {"offset": float, "color": [r, g, b]}."""
class HiddenInputTypeDict(TypedDict): class HiddenInputTypeDict(TypedDict):

View File

@ -93,6 +93,50 @@ class IndexListCallbacks:
return {} return {}
def slice_cond(cond_value, window: IndexListContextWindow, x_in: torch.Tensor, device, temporal_dim: int, temporal_scale: int=1, temporal_offset: int=0, retain_index_list: list[int]=[]):
if not (hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor)):
return None
cond_tensor = cond_value.cond
if temporal_dim >= cond_tensor.ndim:
return None
cond_size = cond_tensor.size(temporal_dim)
if temporal_scale == 1:
expected_size = x_in.size(window.dim) - temporal_offset
if cond_size != expected_size:
return None
if temporal_offset == 0 and temporal_scale == 1:
sliced = window.get_tensor(cond_tensor, device, dim=temporal_dim, retain_index_list=retain_index_list)
return cond_value._copy_with(sliced)
# skip leading latent positions that have no corresponding conditioning (e.g. reference frames)
if temporal_offset > 0:
indices = [i - temporal_offset for i in window.index_list[temporal_offset:]]
indices = [i for i in indices if 0 <= i]
else:
indices = list(window.index_list)
if not indices:
return None
if temporal_scale > 1:
scaled = []
for i in indices:
for k in range(temporal_scale):
si = i * temporal_scale + k
if si < cond_size:
scaled.append(si)
indices = scaled
if not indices:
return None
idx = tuple([slice(None)] * temporal_dim + [indices])
sliced = cond_tensor[idx].to(device)
return cond_value._copy_with(sliced)
@dataclass @dataclass
class ContextSchedule: class ContextSchedule:
name: str name: str
@ -177,10 +221,17 @@ class IndexListContextHandler(ContextHandlerABC):
new_cond_item[cond_key] = result new_cond_item[cond_key] = result
handled = True handled = True
break break
if not handled and self._model is not None:
result = self._model.resize_cond_for_context_window(
cond_key, cond_value, window, x_in, device,
retain_index_list=self.cond_retain_index_list)
if result is not None:
new_cond_item[cond_key] = result
handled = True
if handled: if handled:
continue continue
if isinstance(cond_value, torch.Tensor): if isinstance(cond_value, torch.Tensor):
if (self.dim < cond_value.ndim and cond_value(self.dim) == x_in.size(self.dim)) or \ if (self.dim < cond_value.ndim and cond_value.size(self.dim) == x_in.size(self.dim)) or \
(cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)): (cond_value.ndim < self.dim and cond_value.size(0) == x_in.size(self.dim)):
new_cond_item[cond_key] = window.get_tensor(cond_value, device) new_cond_item[cond_key] = window.get_tensor(cond_value, device)
# Handle audio_embed (temporal dim is 1) # Handle audio_embed (temporal dim is 1)
@ -224,6 +275,7 @@ class IndexListContextHandler(ContextHandlerABC):
return context_windows return context_windows
def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]): def execute(self, calc_cond_batch: Callable, model: BaseModel, conds: list[list[dict]], x_in: torch.Tensor, timestep: torch.Tensor, model_options: dict[str]):
self._model = model
self.set_step(timestep, model_options) self.set_step(timestep, model_options)
context_windows = self.get_context_windows(model, x_in, model_options) context_windows = self.get_context_windows(model, x_in, model_options)
enumerated_context_windows = list(enumerate(context_windows)) enumerated_context_windows = list(enumerate(context_windows))

View File

@ -209,3 +209,39 @@ def stochastic_round_quantize_nvfp4_by_block(x, per_tensor_scale, pad_16x, seed=
output_block[i:i + slice_size].copy_(block) output_block[i:i + slice_size].copy_(block)
return output_fp4, to_blocked(output_block, flatten=False) return output_fp4, to_blocked(output_block, flatten=False)
def stochastic_round_quantize_mxfp8_by_block(x, pad_32x, seed=0):
    """Quantize a 2-D tensor to float8_e4m3fn with one E8M0 scale per 32-element block.

    Args:
        x: Tensor whose shape unpacks as ``(rows, cols)``; each row is split
            into blocks of 32 elements, so ``cols`` must be a multiple of 32
            unless ``pad_32x`` is set.
        pad_32x: When truthy, zero-pad rows and columns up to the next
            multiple of 32 before quantizing.
        seed: Forwarded to ``stochastic_rounding``.

    Returns:
        ``(output_fp8, block_scales)``: the stochastically rounded fp8 data and
        the biased-exponent scale bytes passed through
        ``to_blocked(..., flatten=False)`` and viewed as ``torch.float8_e8m0fnu``.
    """
    F8_E4M3_MAX = 448.0
    E8M0_BIAS = 127
    BLOCK_SIZE = 32

    if pad_32x:
        n_rows, n_cols = x.shape
        # Distance from each dimension to its next multiple of 32.
        extra_rows = (-n_rows) % 32
        extra_cols = (-n_cols) % 32
        if extra_rows or extra_cols:
            x = torch.nn.functional.pad(x, (0, extra_cols, 0, extra_rows))

    rows, cols = x.shape
    blocks = x.reshape(rows, -1, BLOCK_SIZE)
    block_amax = torch.amax(torch.abs(blocks), dim=-1)
    all_zero = block_amax == 0

    # E8M0 block scales (power-of-2 exponents): the smallest power of two
    # that brings the block maximum within the fp8 range, stored as a
    # biased exponent byte.
    needed_scale = torch.clamp(block_amax.float() / F8_E4M3_MAX, min=2**(-127))
    exponent = torch.ceil(torch.log2(needed_scale)).to(torch.int32) + E8M0_BIAS
    scale_bytes = torch.clamp(exponent, 0, 254).to(torch.uint8)

    # Placing the biased exponent in the float32 exponent field yields the
    # scale as a float; all-zero blocks divide by 1 instead.
    scale_f32 = (scale_bytes.to(torch.int32) << 23).view(torch.float32)
    scale_f32 = torch.where(all_zero, torch.ones_like(scale_f32), scale_f32)

    # Scale per-block then stochastic round
    scaled = blocks.float() / scale_f32.unsqueeze(-1)
    output_fp8 = stochastic_rounding(scaled.reshape(rows, cols), torch.float8_e4m3fn, seed=seed)

    # All-zero blocks are stored with a zero scale byte.
    scale_bytes = torch.where(all_zero, torch.zeros_like(scale_bytes), scale_bytes)
    blocked_scales = to_blocked(scale_bytes, flatten=False)
    return output_fp8, blocked_scales.view(torch.float8_e8m0fnu)

View File

@ -136,16 +136,7 @@ class ResBlock(nn.Module):
ops.Linear(c_hidden, c), ops.Linear(c_hidden, c),
) )
self.gammas = nn.Parameter(torch.zeros(6), requires_grad=True) self.gammas = nn.Parameter(torch.zeros(6), requires_grad=False)
# Init weights
def _basic_init(module):
if isinstance(module, nn.Linear) or isinstance(module, nn.Conv2d):
torch.nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.constant_(module.bias, 0)
self.apply(_basic_init)
def _norm(self, x, norm): def _norm(self, x, norm):
return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2) return norm(x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)

View File

@ -144,9 +144,9 @@ def apply_mod(tensor, m_mult, m_add=None, modulation_dims=None):
return tensor * m_mult return tensor * m_mult
else: else:
for d in modulation_dims: for d in modulation_dims:
tensor[:, d[0]:d[1]] *= m_mult[:, d[2]] tensor[:, d[0]:d[1]] *= m_mult[:, d[2]:d[2] + 1]
if m_add is not None: if m_add is not None:
tensor[:, d[0]:d[1]] += m_add[:, d[2]] tensor[:, d[0]:d[1]] += m_add[:, d[2]:d[2] + 1]
return tensor return tensor

View File

@ -44,6 +44,22 @@ class FluxParams:
txt_norm: bool = False txt_norm: bool = False
def invert_slices(slices, length):
    """Return the gaps of ``[0, length)`` left uncovered by *slices*.

    *slices* is an iterable of ``(start, end)`` pairs in any order; pairs may
    overlap. The result is a list of ``(start, end)`` tuples in ascending order.
    """
    gaps = []
    cursor = 0
    for lo, hi in sorted(slices):
        # Anything between the covered prefix and this slice is a gap.
        if lo > cursor:
            gaps.append((cursor, lo))
        # Only ever advance, so a slice nested in an earlier one is a no-op.
        if hi > cursor:
            cursor = hi
    if length > cursor:
        gaps.append((cursor, length))
    return gaps
class Flux(nn.Module): class Flux(nn.Module):
""" """
Transformer model for flow matching on sequences. Transformer model for flow matching on sequences.
@ -138,6 +154,7 @@ class Flux(nn.Module):
y: Tensor, y: Tensor,
guidance: Tensor = None, guidance: Tensor = None,
control = None, control = None,
timestep_zero_index=None,
transformer_options={}, transformer_options={},
attn_mask: Tensor = None, attn_mask: Tensor = None,
) -> Tensor: ) -> Tensor:
@ -164,10 +181,6 @@ class Flux(nn.Module):
txt = self.txt_norm(txt) txt = self.txt_norm(txt)
txt = self.txt_in(txt) txt = self.txt_in(txt)
vec_orig = vec
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(vec_orig))
if "post_input" in patches: if "post_input" in patches:
for p in patches["post_input"]: for p in patches["post_input"]:
out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options}) out = p({"img": img, "txt": txt, "img_ids": img_ids, "txt_ids": txt_ids, "transformer_options": transformer_options})
@ -182,6 +195,24 @@ class Flux(nn.Module):
else: else:
pe = None pe = None
vec_orig = vec
txt_vec = vec
extra_kwargs = {}
if timestep_zero_index is not None:
modulation_dims = []
batch = vec.shape[0] // 2
vec_orig = vec_orig.reshape(2, batch, vec.shape[1]).movedim(0, 1)
invert = invert_slices(timestep_zero_index, img.shape[1])
for s in invert:
modulation_dims.append((s[0], s[1], 0))
for s in timestep_zero_index:
modulation_dims.append((s[0], s[1], 1))
extra_kwargs["modulation_dims_img"] = modulation_dims
txt_vec = vec[:batch]
if self.params.global_modulation:
vec = (self.double_stream_modulation_img(vec_orig), self.double_stream_modulation_txt(txt_vec))
blocks_replace = patches_replace.get("dit", {}) blocks_replace = patches_replace.get("dit", {})
transformer_options["total_blocks"] = len(self.double_blocks) transformer_options["total_blocks"] = len(self.double_blocks)
transformer_options["block_type"] = "double" transformer_options["block_type"] = "double"
@ -195,7 +226,8 @@ class Flux(nn.Module):
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask"), attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options")) transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out return out
out = blocks_replace[("double_block", i)]({"img": img, out = blocks_replace[("double_block", i)]({"img": img,
@ -213,7 +245,8 @@ class Flux(nn.Module):
vec=vec, vec=vec,
pe=pe, pe=pe,
attn_mask=attn_mask, attn_mask=attn_mask,
transformer_options=transformer_options) transformer_options=transformer_options,
**extra_kwargs)
if control is not None: # Controlnet if control is not None: # Controlnet
control_i = control.get("input") control_i = control.get("input")
@ -230,6 +263,12 @@ class Flux(nn.Module):
if self.params.global_modulation: if self.params.global_modulation:
vec, _ = self.single_stream_modulation(vec_orig) vec, _ = self.single_stream_modulation(vec_orig)
extra_kwargs = {}
if timestep_zero_index is not None:
lambda a: 0 if a == 0 else a + txt.shape[1]
modulation_dims_combined = list(map(lambda x: (0 if x[0] == 0 else x[0] + txt.shape[1], x[1] + txt.shape[1], x[2]), modulation_dims))
extra_kwargs["modulation_dims"] = modulation_dims_combined
transformer_options["total_blocks"] = len(self.single_blocks) transformer_options["total_blocks"] = len(self.single_blocks)
transformer_options["block_type"] = "single" transformer_options["block_type"] = "single"
transformer_options["img_slice"] = [txt.shape[1], img.shape[1]] transformer_options["img_slice"] = [txt.shape[1], img.shape[1]]
@ -242,7 +281,8 @@ class Flux(nn.Module):
vec=args["vec"], vec=args["vec"],
pe=args["pe"], pe=args["pe"],
attn_mask=args.get("attn_mask"), attn_mask=args.get("attn_mask"),
transformer_options=args.get("transformer_options")) transformer_options=args.get("transformer_options"),
**extra_kwargs)
return out return out
out = blocks_replace[("single_block", i)]({"img": img, out = blocks_replace[("single_block", i)]({"img": img,
@ -253,7 +293,7 @@ class Flux(nn.Module):
{"original_block": block_wrap}) {"original_block": block_wrap})
img = out["img"] img = out["img"]
else: else:
img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options) img = block(img, vec=vec, pe=pe, attn_mask=attn_mask, transformer_options=transformer_options, **extra_kwargs)
if control is not None: # Controlnet if control is not None: # Controlnet
control_o = control.get("output") control_o = control.get("output")
@ -264,7 +304,11 @@ class Flux(nn.Module):
img = img[:, txt.shape[1] :, ...] img = img[:, txt.shape[1] :, ...]
img = self.final_layer(img, vec_orig) # (N, T, patch_size ** 2 * out_channels) extra_kwargs = {}
if timestep_zero_index is not None:
extra_kwargs["modulation_dims"] = modulation_dims
img = self.final_layer(img, vec_orig, **extra_kwargs) # (N, T, patch_size ** 2 * out_channels)
return img return img
def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}): def process_img(self, x, index=0, h_offset=0, w_offset=0, transformer_options={}):
@ -312,13 +356,16 @@ class Flux(nn.Module):
w_len = ((w_orig + (patch_size // 2)) // patch_size) w_len = ((w_orig + (patch_size // 2)) // patch_size)
img, img_ids = self.process_img(x, transformer_options=transformer_options) img, img_ids = self.process_img(x, transformer_options=transformer_options)
img_tokens = img.shape[1] img_tokens = img.shape[1]
timestep_zero_index = None
if ref_latents is not None: if ref_latents is not None:
ref_num_tokens = []
h = 0 h = 0
w = 0 w = 0
index = 0 index = 0
ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method) ref_latents_method = kwargs.get("ref_latents_method", self.params.default_ref_method)
timestep_zero = ref_latents_method == "index_timestep_zero"
for ref in ref_latents: for ref in ref_latents:
if ref_latents_method == "index": if ref_latents_method in ("index", "index_timestep_zero"):
index += self.params.ref_index_scale index += self.params.ref_index_scale
h_offset = 0 h_offset = 0
w_offset = 0 w_offset = 0
@ -339,9 +386,16 @@ class Flux(nn.Module):
h = max(h, ref.shape[-2] + h_offset) h = max(h, ref.shape[-2] + h_offset)
w = max(w, ref.shape[-1] + w_offset) w = max(w, ref.shape[-1] + w_offset)
kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset) kontext, kontext_ids = self.process_img(ref, index=index, h_offset=h_offset, w_offset=w_offset, transformer_options=transformer_options)
img = torch.cat([img, kontext], dim=1) img = torch.cat([img, kontext], dim=1)
img_ids = torch.cat([img_ids, kontext_ids], dim=1) img_ids = torch.cat([img_ids, kontext_ids], dim=1)
ref_num_tokens.append(kontext.shape[1])
if timestep_zero:
if index > 0:
timestep = torch.cat([timestep, timestep * 0], dim=0)
timestep_zero_index = [[img_tokens, img_ids.shape[1]]]
transformer_options = transformer_options.copy()
transformer_options["reference_image_num_tokens"] = ref_num_tokens
txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32) txt_ids = torch.zeros((bs, context.shape[1], len(self.params.axes_dim)), device=x.device, dtype=torch.float32)
@ -349,6 +403,6 @@ class Flux(nn.Module):
for i in self.params.txt_ids_dims: for i in self.params.txt_ids_dims:
txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32) txt_ids[:, :, i] = torch.linspace(0, context.shape[1] - 1, steps=context.shape[1], device=x.device, dtype=torch.float32)
out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, transformer_options, attn_mask=kwargs.get("attention_mask", None)) out = self.forward_orig(img, img_ids, context, txt_ids, timestep, y, guidance, control, timestep_zero_index=timestep_zero_index, transformer_options=transformer_options, attn_mask=kwargs.get("attention_mask", None))
out = out[:, :img_tokens] out = out[:, :img_tokens]
return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig] return rearrange(out, "b (h w) (c ph pw) -> b c (h ph) (w pw)", h=h_len, w=w_len, ph=self.patch_size, pw=self.patch_size)[:,:,:h_orig,:w_orig]

View File

@ -343,6 +343,7 @@ class CrossAttention(nn.Module):
k.reshape(b, s2, self.num_heads * self.head_dim), k.reshape(b, s2, self.num_heads * self.head_dim),
v, v,
heads=self.num_heads, heads=self.num_heads,
low_precision_attention=False,
) )
out = self.out_proj(x) out = self.out_proj(x)
@ -412,6 +413,7 @@ class Attention(nn.Module):
key.reshape(B, N, self.num_heads * self.head_dim), key.reshape(B, N, self.num_heads * self.head_dim),
value, value,
heads=self.num_heads, heads=self.num_heads,
low_precision_attention=False,
) )
x = self.out_proj(x) x = self.out_proj(x)

View File

@ -681,6 +681,33 @@ class LTXAVModel(LTXVModel):
additional_args["has_spatial_mask"] = has_spatial_mask additional_args["has_spatial_mask"] = has_spatial_mask
ax, a_latent_coords = self.a_patchifier.patchify(ax) ax, a_latent_coords = self.a_patchifier.patchify(ax)
# Inject reference audio for ID-LoRA in-context conditioning
ref_audio = kwargs.get("ref_audio", None)
ref_audio_seq_len = 0
if ref_audio is not None:
ref_tokens = ref_audio["tokens"].to(dtype=ax.dtype, device=ax.device)
if ref_tokens.shape[0] < ax.shape[0]:
ref_tokens = ref_tokens.expand(ax.shape[0], -1, -1)
ref_audio_seq_len = ref_tokens.shape[1]
B = ax.shape[0]
# Compute negative temporal positions matching ID-LoRA convention:
# offset by -(end_of_last_token + time_per_latent) so reference ends just before t=0
p = self.a_patchifier
tpl = p.hop_length * p.audio_latent_downsample_factor / p.sample_rate
ref_start = p._get_audio_latent_time_in_sec(0, ref_audio_seq_len, torch.float32, ax.device)
ref_end = p._get_audio_latent_time_in_sec(1, ref_audio_seq_len + 1, torch.float32, ax.device)
time_offset = ref_end[-1].item() + tpl
ref_start = (ref_start - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_end = (ref_end - time_offset).unsqueeze(0).expand(B, -1).unsqueeze(1)
ref_pos = torch.stack([ref_start, ref_end], dim=-1)
additional_args["ref_audio_seq_len"] = ref_audio_seq_len
additional_args["target_audio_seq_len"] = ax.shape[1]
ax = torch.cat([ref_tokens, ax], dim=1)
a_latent_coords = torch.cat([ref_pos.to(a_latent_coords), a_latent_coords], dim=2)
ax = self.audio_patchify_proj(ax) ax = self.audio_patchify_proj(ax)
# additional_args.update({"av_orig_shape": list(x.shape)}) # additional_args.update({"av_orig_shape": list(x.shape)})
@ -721,6 +748,14 @@ class LTXAVModel(LTXVModel):
# Prepare audio timestep # Prepare audio timestep
a_timestep = kwargs.get("a_timestep") a_timestep = kwargs.get("a_timestep")
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0 and a_timestep is not None:
# Reference tokens must have timestep=0, expand scalar/1D timestep to per-token so ref=0 and target=sigma.
target_len = kwargs.get("target_audio_seq_len")
if a_timestep.dim() <= 1:
a_timestep = a_timestep.view(-1, 1).expand(batch_size, target_len)
ref_ts = torch.zeros(batch_size, ref_audio_seq_len, *a_timestep.shape[2:], device=a_timestep.device, dtype=a_timestep.dtype)
a_timestep = torch.cat([ref_ts, a_timestep], dim=1)
if a_timestep is not None: if a_timestep is not None:
a_timestep_scaled = a_timestep * self.timestep_scale_multiplier a_timestep_scaled = a_timestep * self.timestep_scale_multiplier
a_timestep_flat = a_timestep_scaled.flatten() a_timestep_flat = a_timestep_scaled.flatten()
@ -955,6 +990,13 @@ class LTXAVModel(LTXVModel):
v_embedded_timestep = embedded_timestep[0] v_embedded_timestep = embedded_timestep[0]
a_embedded_timestep = embedded_timestep[1] a_embedded_timestep = embedded_timestep[1]
# Trim reference audio tokens before unpatchification
ref_audio_seq_len = kwargs.get("ref_audio_seq_len", 0)
if ref_audio_seq_len > 0:
ax = ax[:, ref_audio_seq_len:]
if a_embedded_timestep.shape[1] > 1:
a_embedded_timestep = a_embedded_timestep[:, ref_audio_seq_len:]
# Expand compressed video timestep if needed # Expand compressed video timestep if needed
if isinstance(v_embedded_timestep, CompressedTimestep): if isinstance(v_embedded_timestep, CompressedTimestep):
v_embedded_timestep = v_embedded_timestep.expand() v_embedded_timestep = v_embedded_timestep.expand()

View File

@ -23,6 +23,11 @@ class CausalConv3d(nn.Module):
self.in_channels = in_channels self.in_channels = in_channels
self.out_channels = out_channels self.out_channels = out_channels
if isinstance(stride, int):
self.time_stride = stride
else:
self.time_stride = stride[0]
kernel_size = (kernel_size, kernel_size, kernel_size) kernel_size = (kernel_size, kernel_size, kernel_size)
self.time_kernel_size = kernel_size[0] self.time_kernel_size = kernel_size[0]
@ -58,16 +63,25 @@ class CausalConv3d(nn.Module):
pieces = [ cached, x ] pieces = [ cached, x ]
if is_end and not causal: if is_end and not causal:
pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1))) pieces.append(x[:, :, -1:, :, :].repeat((1, 1, (self.time_kernel_size - 1) // 2, 1, 1)))
input_length = sum([piece.shape[2] for piece in pieces])
cache_length = (self.time_kernel_size - self.time_stride) + ((input_length - self.time_kernel_size) % self.time_stride)
needs_caching = not is_end needs_caching = not is_end
if needs_caching and x.shape[2] >= self.time_kernel_size - 1: if needs_caching and cache_length == 0:
self.temporal_cache_state[tid] = (x[:, :, :0, :, :], False)
needs_caching = False needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) if needs_caching and x.shape[2] >= cache_length:
needs_caching = False
self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
x = torch.cat(pieces, dim=2) x = torch.cat(pieces, dim=2)
del pieces
del cached
if needs_caching: if needs_caching:
self.temporal_cache_state[tid] = (x[:, :, -(self.time_kernel_size - 1):, :, :], False) self.temporal_cache_state[tid] = (x[:, :, -cache_length:, :, :], False)
elif is_end:
self.temporal_cache_state[tid] = (None, True)
return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :] return self.conv(x) if x.shape[2] >= self.time_kernel_size else x[:, :, :0, :, :]

View File

@ -11,6 +11,7 @@ from .causal_conv3d import CausalConv3d
from .pixel_norm import PixelNorm from .pixel_norm import PixelNorm
from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings from ..model import PixArtAlphaCombinedTimestepSizeEmbeddings
import comfy.ops import comfy.ops
import comfy.model_management
from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed from comfy.ldm.modules.diffusionmodules.model import torch_cat_if_needed
ops = comfy.ops.disable_weight_init ops = comfy.ops.disable_weight_init
@ -232,10 +233,7 @@ class Encoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
def forward_orig(self, sample: torch.FloatTensor) -> torch.FloatTensor: def _forward_chunk(self, sample: torch.FloatTensor) -> Optional[torch.FloatTensor]:
r"""The forward method of the `Encoder` class."""
sample = patchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
sample = self.conv_in(sample) sample = self.conv_in(sample)
checkpoint_fn = ( checkpoint_fn = (
@ -246,10 +244,14 @@ class Encoder(nn.Module):
for down_block in self.down_blocks: for down_block in self.down_blocks:
sample = checkpoint_fn(down_block)(sample) sample = checkpoint_fn(down_block)(sample)
if sample is None or sample.shape[2] == 0:
return None
sample = self.conv_norm_out(sample) sample = self.conv_norm_out(sample)
sample = self.conv_act(sample) sample = self.conv_act(sample)
sample = self.conv_out(sample) sample = self.conv_out(sample)
if sample is None or sample.shape[2] == 0:
return None
if self.latent_log_var == "uniform": if self.latent_log_var == "uniform":
last_channel = sample[:, -1:, ...] last_channel = sample[:, -1:, ...]
@ -281,9 +283,35 @@ class Encoder(nn.Module):
return sample return sample
def forward_orig(self, sample: torch.FloatTensor, device=None) -> torch.FloatTensor:
    r"""Encode `sample` chunk by chunk along dim 2 and concatenate the results.

    The first frame is always encoded on its own. Any remaining frames are
    split with `torch.split` into pieces of `chunk_t` frames, where `chunk_t`
    is derived from `get_max_chunk_size` and an estimate of the per-frame size
    after `conv_in`. `mark_conv3d_ended(self)` is called right before the
    last chunk is processed.

    Args:
        sample: input tensor; dim 2 is the axis that gets chunked.
        device: optional device each patchified chunk is moved to, and the
            device used for the chunk-size lookup. When None the lookup uses
            `sample.device` and `.to(device=None)` leaves the chunk in place.

    Returns:
        `torch_cat_if_needed` applied along dim 2 to every chunk output that
        was not None.
    """
    budget_device = sample.device if device is None else device
    # Twice the value used by the decoder: encoder is more memory-efficient than decoder.
    max_chunk_size = 2 * get_max_chunk_size(budget_device)

    # Estimated size of one frame once conv_in has changed the channel count.
    first_frame = sample[:, :, :1, :, :]
    channel_ratio = self.conv_in.out_channels / self.conv_in.in_channels
    frame_size = int(first_frame.numel() * sample.element_size() * channel_ratio)

    chunks = [first_frame]
    if sample.shape[2] > 1:
        chunk_t = max(2, max_chunk_size // frame_size)
        # Snap the chunk length to 2, 4, or a multiple of 8 frames.
        if chunk_t >= 8:
            chunk_t -= chunk_t % 8
        elif chunk_t >= 4:
            chunk_t = 4
        else:
            chunk_t = 2
        chunks.extend(torch.split(sample[:, :, 1:, :, :], chunk_t, dim=2))

    outputs = []
    last_idx = len(chunks) - 1
    for chunk_idx, chunk in enumerate(chunks):
        if chunk_idx == last_idx:
            mark_conv3d_ended(self)
        chunk = patchify(chunk, patch_size_hw=self.patch_size, patch_size_t=1).to(device=device)
        encoded = self._forward_chunk(chunk)
        # _forward_chunk returns None when a chunk produced no output frames.
        if encoded is not None:
            outputs.append(encoded)

    return torch_cat_if_needed(outputs, dim=2)
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
#No encoder support so just flag the end so it doesnt use the cache.
mark_conv3d_ended(self)
try: try:
return self.forward_orig(*args, **kwargs) return self.forward_orig(*args, **kwargs)
finally: finally:
@ -296,7 +324,23 @@ class Encoder(nn.Module):
module.temporal_cache_state.pop(tid, None) module.temporal_cache_state.pop(tid, None)
MAX_CHUNK_SIZE=(128 * 1024 ** 2) MIN_VRAM_FOR_CHUNK_SCALING = 6 * 1024 ** 3
MAX_VRAM_FOR_CHUNK_SCALING = 24 * 1024 ** 3
MIN_CHUNK_SIZE = 32 * 1024 ** 2
MAX_CHUNK_SIZE = 128 * 1024 ** 2
def get_max_chunk_size(device: torch.device) -> int:
    """Choose a chunk-size limit from the total memory reported for `device`.

    The total comes from `comfy.model_management.get_total_memory(dev=device)`.
    At or below `MIN_VRAM_FOR_CHUNK_SCALING` the result is `MIN_CHUNK_SIZE`;
    at or above `MAX_VRAM_FOR_CHUNK_SCALING` it is `MAX_CHUNK_SIZE`; in
    between it is interpolated linearly and truncated to an int.
    """
    vram = comfy.model_management.get_total_memory(dev=device)

    if vram <= MIN_VRAM_FOR_CHUNK_SCALING:
        return MIN_CHUNK_SIZE
    if vram >= MAX_VRAM_FOR_CHUNK_SCALING:
        return MAX_CHUNK_SIZE

    # Position of `vram` inside the scaling range, strictly between 0 and 1 here.
    vram_span = MAX_VRAM_FOR_CHUNK_SCALING - MIN_VRAM_FOR_CHUNK_SCALING
    chunk_span = MAX_CHUNK_SIZE - MIN_CHUNK_SIZE
    fraction = (vram - MIN_VRAM_FOR_CHUNK_SCALING) / vram_span
    return int(MIN_CHUNK_SIZE + fraction * chunk_span)
class Decoder(nn.Module): class Decoder(nn.Module):
r""" r"""
@ -456,6 +500,17 @@ class Decoder(nn.Module):
self.gradient_checkpointing = False self.gradient_checkpointing = False
# Precompute output scale factors: (channels, (t_scale, h_scale, w_scale), t_offset)
ts, hs, ws, to = 1, 1, 1, 0
for block in self.up_blocks:
if isinstance(block, DepthToSpaceUpsample):
ts *= block.stride[0]
hs *= block.stride[1]
ws *= block.stride[2]
if block.stride[0] > 1:
to = to * block.stride[0] + 1
self._output_scale = (out_channels // (patch_size ** 2), (ts, hs * patch_size, ws * patch_size), to)
self.timestep_conditioning = timestep_conditioning self.timestep_conditioning = timestep_conditioning
if timestep_conditioning: if timestep_conditioning:
@ -477,11 +532,62 @@ class Decoder(nn.Module):
) )
# def forward(self, sample: torch.FloatTensor, target_shape) -> torch.FloatTensor: def decode_output_shape(self, input_shape):
c, (ts, hs, ws), to = self._output_scale
return (input_shape[0], c, input_shape[2] * ts - to, input_shape[3] * hs, input_shape[4] * ws)
def run_up(self, idx, sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size):
    """Run up block `idx` and all later stages on one piece of the sample.

    Recurses through `self.up_blocks`. Once `idx` is past the last block the
    output head runs (`conv_norm_out`, optional shift/scale, `conv_act`,
    `conv_out`, `unpatchify`) and the result is copied into `output_buffer`.

    Args:
        idx: index into `self.up_blocks`; `idx >= len(self.up_blocks)` means
            "run the output head".
        sample_ref: one-element list holding the input tensor. The slot is set
            to None here so the caller's list no longer keeps it alive.
        ended: when true, `mark_conv3d_ended` is called on the stage about to
            run. For chunked recursion only the last chunk inherits it.
        timestep_shift_scale: None, or a `(shift, scale)` pair applied after
            `conv_norm_out` as `sample * (1 + scale) + shift`.
        scaled_timestep: forwarded as `timestep=` to `UNetMidBlock3D` blocks
            when `self.timestep_conditioning` is set.
        checkpoint_fn: wrapper applied to each block as
            `checkpoint_fn(block)(...)`.
        output_buffer: tensor that receives the output along dim 2.
        output_offset: one-element list with the next write index along dim 2;
            advanced by the number of frames written.
        max_chunk_size: limit compared against
            `sample.numel() * sample.element_size()` to decide how many
            pieces the block output is split into.
    """
    sample = sample_ref[0]
    sample_ref[0] = None

    if idx >= len(self.up_blocks):
        # Output head: no more up blocks left for this piece.
        sample = self.conv_norm_out(sample)
        if timestep_shift_scale is not None:
            shift, scale = timestep_shift_scale
            sample = sample * (1 + scale) + shift
        sample = self.conv_act(sample)
        if ended:
            mark_conv3d_ended(self.conv_out)
        sample = self.conv_out(sample, causal=self.causal)
        if sample is None or sample.shape[2] == 0:
            return
        sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
        start = output_offset[0]
        stop = start + sample.shape[2]
        output_buffer[:, :, start:stop].copy_(sample)
        output_offset[0] = stop
        return

    up_block = self.up_blocks[idx]
    if ended:
        mark_conv3d_ended(up_block)

    # Only UNetMidBlock3D receives the timestep, and only with timestep conditioning on.
    block_kwargs = {"causal": self.causal}
    if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
        block_kwargs["timestep"] = scaled_timestep
    sample = checkpoint_fn(up_block)(sample, **block_kwargs)

    if sample is None or sample.shape[2] == 0:
        return

    # Ceiling division: number of pieces needed to stay within max_chunk_size.
    total_bytes = sample.numel() * sample.element_size()
    num_chunks = (total_bytes + max_chunk_size - 1) // max_chunk_size

    if num_chunks == 1:
        # Not chunking: hand the tensor over and drop our own reference so the
        # callee can free it as soon as it is done with it.
        next_sample_ref = [sample]
        del sample
        self.run_up(idx + 1, next_sample_ref, ended, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
        return

    pieces = torch.chunk(sample, chunks=num_chunks, dim=2)
    last_piece = len(pieces) - 1
    for piece_idx, piece in enumerate(pieces):
        self.run_up(idx + 1, [piece], ended and piece_idx == last_piece, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
def forward_orig( def forward_orig(
self, self,
sample: torch.FloatTensor, sample: torch.FloatTensor,
timestep: Optional[torch.Tensor] = None, timestep: Optional[torch.Tensor] = None,
output_buffer: Optional[torch.Tensor] = None,
) -> torch.FloatTensor: ) -> torch.FloatTensor:
r"""The forward method of the `Decoder` class.""" r"""The forward method of the `Decoder` class."""
batch_size = sample.shape[0] batch_size = sample.shape[0]
@ -496,6 +602,7 @@ class Decoder(nn.Module):
) )
timestep_shift_scale = None timestep_shift_scale = None
scaled_timestep = None
if self.timestep_conditioning: if self.timestep_conditioning:
assert ( assert (
timestep is not None timestep is not None
@ -523,48 +630,18 @@ class Decoder(nn.Module):
) )
timestep_shift_scale = ada_values.unbind(dim=1) timestep_shift_scale = ada_values.unbind(dim=1)
output = [] if output_buffer is None:
output_buffer = torch.empty(
self.decode_output_shape(sample.shape),
dtype=sample.dtype, device=comfy.model_management.intermediate_device(),
)
output_offset = [0]
def run_up(idx, sample, ended): max_chunk_size = get_max_chunk_size(sample.device)
if idx >= len(self.up_blocks):
sample = self.conv_norm_out(sample)
if timestep_shift_scale is not None:
shift, scale = timestep_shift_scale
sample = sample * (1 + scale) + shift
sample = self.conv_act(sample)
if ended:
mark_conv3d_ended(self.conv_out)
sample = self.conv_out(sample, causal=self.causal)
if sample is not None and sample.shape[2] > 0:
output.append(sample)
return
up_block = self.up_blocks[idx] self.run_up(0, [sample], True, timestep_shift_scale, scaled_timestep, checkpoint_fn, output_buffer, output_offset, max_chunk_size)
if (ended):
mark_conv3d_ended(up_block)
if self.timestep_conditioning and isinstance(up_block, UNetMidBlock3D):
sample = checkpoint_fn(up_block)(
sample, causal=self.causal, timestep=scaled_timestep
)
else:
sample = checkpoint_fn(up_block)(sample, causal=self.causal)
if sample is None or sample.shape[2] == 0: return output_buffer
return
total_bytes = sample.numel() * sample.element_size()
num_chunks = (total_bytes + MAX_CHUNK_SIZE - 1) // MAX_CHUNK_SIZE
samples = torch.chunk(sample, chunks=num_chunks, dim=2)
for chunk_idx, sample1 in enumerate(samples):
run_up(idx + 1, sample1, ended and chunk_idx == len(samples) - 1)
run_up(0, sample, True)
sample = torch.cat(output, dim=2)
sample = unpatchify(sample, patch_size_hw=self.patch_size, patch_size_t=1)
return sample
def forward(self, *args, **kwargs): def forward(self, *args, **kwargs):
try: try:
@ -688,12 +765,25 @@ class SpaceToDepthDownsample(nn.Module):
causal=True, causal=True,
spatial_padding_mode=spatial_padding_mode, spatial_padding_mode=spatial_padding_mode,
) )
self.temporal_cache_state = {}
def forward(self, x, causal: bool = True): def forward(self, x, causal: bool = True):
if self.stride[0] == 2: tid = threading.get_ident()
cached, pad_first, cached_x, cached_input = self.temporal_cache_state.get(tid, (None, True, None, None))
if cached_input is not None:
x = torch_cat_if_needed([cached_input, x], dim=2)
cached_input = None
if self.stride[0] == 2 and pad_first:
x = torch.cat( x = torch.cat(
[x[:, :, :1, :, :], x], dim=2 [x[:, :, :1, :, :], x], dim=2
) # duplicate first frames for padding ) # duplicate first frames for padding
pad_first = False
if x.shape[2] < self.stride[0]:
cached_input = x
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return None
# skip connection # skip connection
x_in = rearrange( x_in = rearrange(
@ -708,15 +798,26 @@ class SpaceToDepthDownsample(nn.Module):
# conv # conv
x = self.conv(x, causal=causal) x = self.conv(x, causal=causal)
x = rearrange( if self.stride[0] == 2 and x.shape[2] == 1:
x, if cached_x is not None:
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w", x = torch_cat_if_needed([cached_x, x], dim=2)
p1=self.stride[0], cached_x = None
p2=self.stride[1], else:
p3=self.stride[2], cached_x = x
) x = None
x = x + x_in if x is not None:
x = rearrange(
x,
"b c (d p1) (h p2) (w p3) -> b (c p1 p2 p3) d h w",
p1=self.stride[0],
p2=self.stride[1],
p3=self.stride[2],
)
cached = add_exchange_cache(x, cached, x_in, dim=2)
self.temporal_cache_state[tid] = (cached, pad_first, cached_x, cached_input)
return x return x
@ -1049,6 +1150,8 @@ class processor(nn.Module):
return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x) return (x - self.get_buffer("mean-of-means").view(1, -1, 1, 1, 1).to(x)) / self.get_buffer("std-of-means").view(1, -1, 1, 1, 1).to(x)
class VideoVAE(nn.Module): class VideoVAE(nn.Module):
comfy_has_chunked_io = True
def __init__(self, version=0, config=None): def __init__(self, version=0, config=None):
super().__init__() super().__init__()
@ -1191,14 +1294,15 @@ class VideoVAE(nn.Module):
} }
return config return config
def encode(self, x): def encode(self, x, device=None):
frames_count = x.shape[2] x = x[:, :, :max(1, 1 + ((x.shape[2] - 1) // 8) * 8), :, :]
if ((frames_count - 1) % 8) != 0: means, logvar = torch.chunk(self.encoder(x, device=device), 2, dim=1)
raise ValueError("Invalid number of frames: Encode input must have 1 + 8 * x frames (e.g., 1, 9, 17, ...). Please check your input.")
means, logvar = torch.chunk(self.encoder(x), 2, dim=1)
return self.per_channel_statistics.normalize(means) return self.per_channel_statistics.normalize(means)
def decode(self, x): def decode_output_shape(self, input_shape):
return self.decoder.decode_output_shape(input_shape)
def decode(self, x, output_buffer=None):
if self.timestep_conditioning: #TODO: seed if self.timestep_conditioning: #TODO: seed
x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x x = torch.randn_like(x) * self.decode_noise_scale + (1.0 - self.decode_noise_scale) * x
return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep) return self.decoder(self.per_channel_statistics.un_normalize(x), timestep=self.decode_timestep, output_buffer=output_buffer)

View File

@ -99,7 +99,7 @@ class Resample(nn.Module):
else: else:
self.resample = nn.Identity() self.resample = nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
b, c, t, h, w = x.size() b, c, t, h, w = x.size()
if self.mode == 'upsample3d': if self.mode == 'upsample3d':
if feat_cache is not None: if feat_cache is not None:
@ -109,22 +109,7 @@ class Resample(nn.Module):
feat_idx[0] += 1 feat_idx[0] += 1
else: else:
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] != 'Rep':
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
if cache_x.shape[2] < 2 and feat_cache[
idx] is not None and feat_cache[idx] == 'Rep':
cache_x = torch.cat([
torch.zeros_like(cache_x).to(cache_x.device),
cache_x
],
dim=2)
if feat_cache[idx] == 'Rep': if feat_cache[idx] == 'Rep':
x = self.time_conv(x) x = self.time_conv(x)
else: else:
@ -145,19 +130,24 @@ class Resample(nn.Module):
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
if feat_cache[idx] is None: if feat_cache[idx] is None:
feat_cache[idx] = x.clone() feat_cache[idx] = x
feat_idx[0] += 1
else: else:
cache_x = x[:, :, -1:, :, :].clone() cache_x = x[:, :, -1:, :, :]
# if cache_x.shape[2] < 2 and feat_cache[idx] is not None and feat_cache[idx]!='Rep':
# # cache last frame of last two chunk
# cache_x = torch.cat([feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(cache_x.device), cache_x], dim=2)
x = self.time_conv( x = self.time_conv(
torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2)) torch.cat([feat_cache[idx][:, :, -1:, :, :], x], 2))
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1
deferred_x = feat_cache[idx + 1]
if deferred_x is not None:
x = torch.cat([deferred_x, x], 2)
feat_cache[idx + 1] = None
if x.shape[2] == 1 and not final:
feat_cache[idx + 1] = x
x = None
feat_idx[0] += 2
return x return x
@ -177,19 +167,12 @@ class ResidualBlock(nn.Module):
self.shortcut = CausalConv3d(in_dim, out_dim, 1) \ self.shortcut = CausalConv3d(in_dim, out_dim, 1) \
if in_dim != out_dim else nn.Identity() if in_dim != out_dim else nn.Identity()
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
old_x = x old_x = x
for layer in self.residual: for layer in self.residual:
if isinstance(layer, CausalConv3d) and feat_cache is not None: if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, cache_list=feat_cache, cache_idx=idx) x = layer(x, cache_list=feat_cache, cache_idx=idx)
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -213,7 +196,7 @@ class AttentionBlock(nn.Module):
self.proj = ops.Conv2d(dim, dim, 1) self.proj = ops.Conv2d(dim, dim, 1)
self.optimized_attention = vae_attention() self.optimized_attention = vae_attention()
def forward(self, x): def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
identity = x identity = x
b, c, t, h, w = x.size() b, c, t, h, w = x.size()
x = rearrange(x, 'b c t h w -> (b t) c h w') x = rearrange(x, 'b c t h w -> (b t) c h w')
@ -283,17 +266,10 @@ class Encoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(), RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, z_dim, 3, padding=1)) CausalConv3d(out_dim, z_dim, 3, padding=1))
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0], final=False):
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx]) x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -303,14 +279,16 @@ class Encoder3d(nn.Module):
## downsamples ## downsamples
for layer in self.downsamples: for layer in self.downsamples:
if feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x = layer(x, feat_cache, feat_idx, final=final)
if x is None:
return None
else: else:
x = layer(x) x = layer(x)
## middle ## middle
for layer in self.middle: for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x = layer(x, feat_cache, feat_idx, final=final)
else: else:
x = layer(x) x = layer(x)
@ -318,14 +296,7 @@ class Encoder3d(nn.Module):
for layer in self.head: for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None: if isinstance(layer, CausalConv3d) and feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx]) x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -389,18 +360,48 @@ class Decoder3d(nn.Module):
RMS_norm(out_dim, images=False), nn.SiLU(), RMS_norm(out_dim, images=False), nn.SiLU(),
CausalConv3d(out_dim, output_channels, 3, padding=1)) CausalConv3d(out_dim, output_channels, 3, padding=1))
def run_up(self, layer_idx, x_ref, feat_cache, feat_idx, out_chunks):
    """Run upsample layer `layer_idx` and everything after it, collecting outputs.

    Args:
        layer_idx: index into `self.upsamples`; once it is past the end the
            head layers run and the result is appended to `out_chunks`.
        x_ref: one-element list holding the input tensor; the slot is set to
            None here so the caller's list stops referencing it.
        feat_cache: list of cached tensors indexed through `feat_idx`, or
            None to run every layer without caching.
        feat_idx: one-element list with the current cache index.
        out_chunks: list that receives one head output per recursion leaf.
    """
    x = x_ref[0]
    x_ref[0] = None

    if layer_idx >= len(self.upsamples):
        # Head: causal convs consume and refresh their cache slot, other layers run plainly.
        for head_layer in self.head:
            if feat_cache is not None and isinstance(head_layer, CausalConv3d):
                slot = feat_idx[0]
                tail = x[:, :, -CACHE_T:, :, :]
                x = head_layer(x, feat_cache[slot])
                feat_cache[slot] = tail
                feat_idx[0] = slot + 1
            else:
                x = head_layer(x)
        out_chunks.append(x)
        return

    layer = self.upsamples[layer_idx]
    x = layer(x) if feat_cache is None else layer(x, feat_cache, feat_idx)

    splits_in_time = isinstance(layer, Resample) and layer.mode == 'upsample3d'
    if splits_in_time and x.shape[2] > 2:
        # Continue with at most two frames at a time. Each piece gets its own
        # copy of feat_idx so all pieces start from the same cache index.
        for start in range(0, x.shape[2], 2):
            piece = x[:, :, start:start + 2, :, :]
            self.run_up(layer_idx + 1, [piece], feat_cache, feat_idx.copy(), out_chunks)
            del piece
        del x
        return

    # Hand the tensor over and drop our own reference before recursing.
    handoff = [x]
    del x
    self.run_up(layer_idx + 1, handoff, feat_cache, feat_idx, out_chunks)
def forward(self, x, feat_cache=None, feat_idx=[0]): def forward(self, x, feat_cache=None, feat_idx=[0]):
## conv1 ## conv1
if feat_cache is not None: if feat_cache is not None:
idx = feat_idx[0] idx = feat_idx[0]
cache_x = x[:, :, -CACHE_T:, :, :].clone() cache_x = x[:, :, -CACHE_T:, :, :]
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = self.conv1(x, feat_cache[idx]) x = self.conv1(x, feat_cache[idx])
feat_cache[idx] = cache_x feat_cache[idx] = cache_x
feat_idx[0] += 1 feat_idx[0] += 1
@ -409,42 +410,21 @@ class Decoder3d(nn.Module):
## middle ## middle
for layer in self.middle: for layer in self.middle:
if isinstance(layer, ResidualBlock) and feat_cache is not None:
x = layer(x, feat_cache, feat_idx)
else:
x = layer(x)
## upsamples
for layer in self.upsamples:
if feat_cache is not None: if feat_cache is not None:
x = layer(x, feat_cache, feat_idx) x = layer(x, feat_cache, feat_idx)
else: else:
x = layer(x) x = layer(x)
## head out_chunks = []
for layer in self.head:
if isinstance(layer, CausalConv3d) and feat_cache is not None: self.run_up(0, [x], feat_cache, feat_idx, out_chunks)
idx = feat_idx[0] return out_chunks
cache_x = x[:, :, -CACHE_T:, :, :].clone()
if cache_x.shape[2] < 2 and feat_cache[idx] is not None:
# cache last frame of last two chunk
cache_x = torch.cat([
feat_cache[idx][:, :, -1, :, :].unsqueeze(2).to(
cache_x.device), cache_x
],
dim=2)
x = layer(x, feat_cache[idx])
feat_cache[idx] = cache_x
feat_idx[0] += 1
else:
x = layer(x)
return x
def count_conv3d(model): def count_cache_layers(model):
count = 0 count = 0
for m in model.modules(): for m in model.modules():
if isinstance(m, CausalConv3d): if isinstance(m, CausalConv3d) or (isinstance(m, Resample) and m.mode == 'downsample3d'):
count += 1 count += 1
return count return count
@ -482,11 +462,12 @@ class WanVAE(nn.Module):
conv_idx = [0] conv_idx = [0]
## cache ## cache
t = x.shape[2] t = x.shape[2]
iter_ = 1 + (t - 1) // 4 t = 1 + ((t - 1) // 4) * 4
iter_ = 1 + (t - 1) // 2
feat_map = None feat_map = None
if iter_ > 1: if iter_ > 1:
feat_map = [None] * count_conv3d(self.encoder) feat_map = [None] * count_cache_layers(self.encoder)
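## Split the encode input x along time into chunks of 1, 4, 4, 4, ... ## Split the encode input x along time into chunks of 1, 2, 2, 2, ... (the total frame count is first rounded down to 4N+1)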
for i in range(iter_): for i in range(iter_):
conv_idx = [0] conv_idx = [0]
if i == 0: if i == 0:
@ -496,20 +477,23 @@ class WanVAE(nn.Module):
feat_idx=conv_idx) feat_idx=conv_idx)
else: else:
out_ = self.encoder( out_ = self.encoder(
x[:, :, 1 + 4 * (i - 1):1 + 4 * i, :, :], x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map, feat_cache=feat_map,
feat_idx=conv_idx) feat_idx=conv_idx,
final=(i == (iter_ - 1)))
if out_ is None:
continue
out = torch.cat([out, out_], 2) out = torch.cat([out, out_], 2)
mu, log_var = self.conv1(out).chunk(2, dim=1) mu, log_var = self.conv1(out).chunk(2, dim=1)
return mu return mu
def decode(self, z): def decode(self, z):
conv_idx = [0]
# z: [b,c,t,h,w] # z: [b,c,t,h,w]
iter_ = z.shape[2] iter_ = 1 + z.shape[2] // 2
feat_map = None feat_map = None
if iter_ > 1: if iter_ > 1:
feat_map = [None] * count_conv3d(self.decoder) feat_map = [None] * count_cache_layers(self.decoder)
x = self.conv2(z) x = self.conv2(z)
for i in range(iter_): for i in range(iter_):
conv_idx = [0] conv_idx = [0]
@ -520,8 +504,8 @@ class WanVAE(nn.Module):
feat_idx=conv_idx) feat_idx=conv_idx)
else: else:
out_ = self.decoder( out_ = self.decoder(
x[:, :, i:i + 1, :, :], x[:, :, 1 + 2 * (i - 1):1 + 2 * i, :, :],
feat_cache=feat_map, feat_cache=feat_map,
feat_idx=conv_idx) feat_idx=conv_idx)
out = torch.cat([out, out_], 2) out += out_
return out return torch.cat(out, 2)

View File

@ -1,9 +1,71 @@
import math import math
import ctypes
import threading
import dataclasses
import torch import torch
from typing import NamedTuple from typing import NamedTuple
from comfy.quant_ops import QuantizedTensor from comfy.quant_ops import QuantizedTensor
class TensorFileSlice(NamedTuple):
    """Where a tensor's raw bytes live inside an open file.

    `read_tensor_file_slice_into` looks for an instance of this tuple on a
    tensor's untyped storage under the `_comfy_tensor_file_slice` attribute.
    """

    # File object the bytes are read from via seek() / readinto(); may be None.
    file_ref: object
    # Compared against threading.get_ident() before any read is attempted.
    thread_id: int
    # Position passed to seek() before reading.
    offset: int
    # Number of bytes to read.
    size: int
def read_tensor_file_slice_into(tensor, destination):
    """Fill `destination` by reading `tensor`'s bytes directly from its backing file.

    The fast path applies only when `tensor`'s untyped storage carries a
    `_comfy_tensor_file_slice` record and all of the following hold:

    * `destination` is a contiguous CPU tensor with room for `info.size` bytes,
    * the record has a file object and was created on the calling thread,
    * `tensor` is contiguous, starts at storage offset 0, and is exactly
      `info.size` bytes long.

    For `QuantizedTensor` inputs both sides must be quantized with the same
    layout class. The underlying `_qdata` is read recursively and the
    parameters are copied over, with the destination keeping its own
    `orig_dtype`.

    Args:
        tensor: source tensor whose storage may reference a file slice.
        destination: tensor whose memory receives the bytes.

    Returns:
        True if `destination` now holds the data (trivially so for a
        zero-size slice). False if the fast path does not apply or the file
        could not be read in full. On a failed read `destination` may have
        been partially overwritten, so the caller is expected to fall back to
        a regular copy.
    """
    if isinstance(tensor, QuantizedTensor):
        if not isinstance(destination, QuantizedTensor):
            return False
        if tensor._layout_cls != destination._layout_cls:
            return False
        if not read_tensor_file_slice_into(tensor._qdata, destination._qdata):
            return False
        # Adopt the source parameters but preserve the destination's orig_dtype.
        dst_orig_dtype = destination._params.orig_dtype
        destination._params.copy_from(tensor._params, non_blocking=False)
        destination._params = dataclasses.replace(destination._params, orig_dtype=dst_orig_dtype)
        return True

    info = getattr(tensor.untyped_storage(), "_comfy_tensor_file_slice", None)
    if info is None:
        return False

    file_obj = info.file_ref
    if (destination.device.type != "cpu"
            or file_obj is None
            or threading.get_ident() != info.thread_id
            # The bytes are written linearly from data_ptr(), so a strided
            # destination would be filled in the wrong element order.
            or not destination.is_contiguous()
            or destination.numel() * destination.element_size() < info.size
            or tensor.numel() * tensor.element_size() != info.size
            or tensor.storage_offset() != 0
            or not tensor.is_contiguous()):
        return False

    if info.size == 0:
        return True

    # Expose the destination memory as a writable byte buffer for readinto().
    buf_type = ctypes.c_ubyte * info.size
    view = memoryview(buf_type.from_address(destination.data_ptr()))
    try:
        file_obj.seek(info.offset)
        done = 0
        while done < info.size:
            n = file_obj.readinto(view[done:])
            # None (no data available on a non-blocking stream) or 0 (EOF)
            # both mean the slice cannot be completed.
            if n is None or n <= 0:
                return False
            done += n
        return True
    except OSError:
        # Seek and read failures are reported the same way: fast path unavailable.
        return False
    finally:
        view.release()
class TensorGeometry(NamedTuple): class TensorGeometry(NamedTuple):
shape: any shape: any
dtype: torch.dtype dtype: torch.dtype

View File

@ -21,6 +21,7 @@ import comfy.ldm.hunyuan3dv2_1.hunyuandit
import torch import torch
import logging import logging
import comfy.ldm.lightricks.av_model import comfy.ldm.lightricks.av_model
import comfy.context_windows
from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep from comfy.ldm.modules.diffusionmodules.openaimodel import UNetModel, Timestep
from comfy.ldm.cascade.stage_c import StageC from comfy.ldm.cascade.stage_c import StageC
from comfy.ldm.cascade.stage_b import StageB from comfy.ldm.cascade.stage_b import StageB
@ -285,6 +286,12 @@ class BaseModel(torch.nn.Module):
return data return data
return None return None
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Hook for model-specific slicing of a cond when context windows are used.

    The base implementation handles no cond key: it always returns None, which
    tells the caller to fall through to its default handling. Subclasses
    override this to return a sliced cond object for the keys they own;
    comfy.context_windows.slice_cond() covers the common cases.
    """
    return None
def extra_conds(self, **kwargs): def extra_conds(self, **kwargs):
out = {} out = {}
concat_cond = self.concat_cond(**kwargs) concat_cond = self.concat_cond(**kwargs)
@ -930,9 +937,10 @@ class LongCatImage(Flux):
transformer_options = transformer_options.copy() transformer_options = transformer_options.copy()
rope_opts = transformer_options.get("rope_options", {}) rope_opts = transformer_options.get("rope_options", {})
rope_opts = dict(rope_opts) rope_opts = dict(rope_opts)
pe_len = float(c_crossattn.shape[1]) if c_crossattn is not None else 512.0
rope_opts.setdefault("shift_t", 1.0) rope_opts.setdefault("shift_t", 1.0)
rope_opts.setdefault("shift_y", 512.0) rope_opts.setdefault("shift_y", pe_len)
rope_opts.setdefault("shift_x", 512.0) rope_opts.setdefault("shift_x", pe_len)
transformer_options["rope_options"] = rope_opts transformer_options["rope_options"] = rope_opts
return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs) return super()._apply_model(x, t, c_concat, c_crossattn, control, transformer_options, **kwargs)
@ -1053,6 +1061,10 @@ class LTXAV(BaseModel):
if guide_attention_entries is not None: if guide_attention_entries is not None:
out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries) out['guide_attention_entries'] = comfy.conds.CONDConstant(guide_attention_entries)
ref_audio = kwargs.get("ref_audio", None)
if ref_audio is not None:
out['ref_audio'] = comfy.conds.CONDConstant(ref_audio)
return out return out
def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs): def process_timestep(self, timestep, x, denoise_mask=None, audio_denoise_mask=None, **kwargs):
@ -1375,6 +1387,11 @@ class WAN21_Vace(WAN21):
out['vace_strength'] = comfy.conds.CONDConstant(vace_strength) out['vace_strength'] = comfy.conds.CONDConstant(vace_strength)
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the "vace_context" cond for a context window; defer all other keys to the parent."""
    if cond_key != "vace_context":
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
    # slice_cond() is told to window along dim 3 and receives the retained indices.
    return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=3, retain_index_list=retain_index_list)
class WAN21_Camera(WAN21): class WAN21_Camera(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.CameraWanModel)
@ -1427,6 +1444,11 @@ class WAN21_HuMo(WAN21):
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the "audio_embed" cond for a context window; defer all other keys to the parent."""
    if cond_key != "audio_embed":
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
    # retain_index_list is not forwarded to slice_cond() for this cond.
    return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
class WAN22_Animate(WAN21): class WAN22_Animate(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_animate.AnimateWanModel)
@ -1444,6 +1466,13 @@ class WAN22_Animate(WAN21):
out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents)) out['pose_latents'] = comfy.conds.CONDRegular(self.process_latent_in(pose_latents))
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the "face_pixel_values" and "pose_latents" conds for a context window.

    Any other cond key is deferred to the parent implementation.
    """
    # slice_cond() keyword arguments for each cond key this model handles itself.
    slice_kwargs = {
        "face_pixel_values": {"temporal_dim": 2, "temporal_scale": 4, "temporal_offset": 1},
        "pose_latents": {"temporal_dim": 2, "temporal_offset": 1},
    }.get(cond_key)
    if slice_kwargs is None:
        return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
    return comfy.context_windows.slice_cond(cond_value, window, x_in, device, **slice_kwargs)
class WAN22_S2V(WAN21): class WAN22_S2V(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel_S2V)
@ -1480,6 +1509,11 @@ class WAN22_S2V(WAN21):
out['reference_motion'] = reference_motion.shape out['reference_motion'] = reference_motion.shape
return out return out
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
    """Slice the "audio_embed" cond for a context window; defer all other keys to the parent."""
    handled_here = cond_key == "audio_embed"
    if handled_here:
        # retain_index_list is not forwarded to slice_cond() for this cond.
        return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=1)
    return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
class WAN22(WAN21): class WAN22(WAN21):
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None): def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel) super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.WanModel)

View File

@ -270,10 +270,15 @@ try:
except: except:
OOM_EXCEPTION = Exception OOM_EXCEPTION = Exception
# Use torch.AcceleratorError when this torch build defines it; otherwise fall
# back to RuntimeError so isinstance() checks against it still work.
ACCELERATOR_ERROR = getattr(torch, "AcceleratorError", RuntimeError)
def is_oom(e):
    """Return True when exception `e` is recognised as an out-of-memory error.

    An instance of OOM_EXCEPTION always counts. An ACCELERATOR_ERROR counts
    when its error_code attribute equals 2 or its message contains
    "out of memory" (case-insensitive); in that case
    discard_cuda_async_error() is called before returning.
    """
    if isinstance(e, OOM_EXCEPTION):
        return True
    if not isinstance(e, ACCELERATOR_ERROR):
        return False
    looks_like_oom = getattr(e, 'error_code', None) == 2 or "out of memory" in str(e).lower()
    if looks_like_oom:
        discard_cuda_async_error()
        return True
    return False
@ -395,7 +400,7 @@ try:
if args.use_split_cross_attention == False and args.use_quad_cross_attention == False: if args.use_split_cross_attention == False and args.use_quad_cross_attention == False:
if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton. if aotriton_supported(arch): # AMD efficient attention implementation depends on aotriton.
if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much if torch_version_numeric >= (2, 7): # works on 2.6 but doesn't actually seem to improve much
if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1151"]): # TODO: more arches, TODO: gfx950 if any((a in arch) for a in ["gfx90a", "gfx942", "gfx950", "gfx1100", "gfx1101", "gfx1150", "gfx1151"]): # TODO: more arches, TODO: gfx950
ENABLE_PYTORCH_ATTENTION = True ENABLE_PYTORCH_ATTENTION = True
if rocm_version >= (7, 0): if rocm_version >= (7, 0):
if any((a in arch) for a in ["gfx1200", "gfx1201"]): if any((a in arch) for a in ["gfx1200", "gfx1201"]):
@ -500,6 +505,28 @@ def module_size(module):
module_mem += t.nbytes module_mem += t.nbytes
return module_mem return module_mem
def module_mmap_residency(module, free=False):
    """Report how many bytes of a module's state dict sit in touched mmap storage.

    A tensor counts as touched when its untyped storage (the `_qdata` storage
    for a QuantizedTensor) carries a truthy `_comfy_tensor_mmap_touched` flag.

    When `free` is true, each touched storage has that flag cleared and
    bounce() is called on the first entry of its `_comfy_tensor_mmap_refs`,
    at most once per distinct mmap object.

    Returns a tuple (touched_bytes, total_bytes).
    """
    touched_bytes = 0
    total_bytes = 0
    already_bounced = set()
    for tensor in module.state_dict().values():
        nbytes = tensor.nbytes
        total_bytes += nbytes
        if isinstance(tensor, comfy.quant_ops.QuantizedTensor):
            storage = tensor._qdata.untyped_storage()
        else:
            storage = tensor.untyped_storage()
        if not getattr(storage, "_comfy_tensor_mmap_touched", False):
            continue
        touched_bytes += nbytes
        if not free:
            continue
        storage._comfy_tensor_mmap_touched = False
        backing_mmap = storage._comfy_tensor_mmap_refs[0]
        # Several tensors can share one mmap; bounce each only once.
        if backing_mmap not in already_bounced:
            backing_mmap.bounce()
            already_bounced.add(backing_mmap)
    return touched_bytes, total_bytes
class LoadedModel: class LoadedModel:
def __init__(self, model): def __init__(self, model):
self._set_model(model) self._set_model(model)
@ -514,6 +541,7 @@ class LoadedModel:
if model.parent is not None: if model.parent is not None:
self._parent_model = weakref.ref(model.parent) self._parent_model = weakref.ref(model.parent)
self._patcher_finalizer = weakref.finalize(model, self._switch_parent) self._patcher_finalizer = weakref.finalize(model, self._switch_parent)
self._patcher_finalizer.atexit = False
def _switch_parent(self): def _switch_parent(self):
model = self._parent_model() model = self._parent_model()
@ -527,6 +555,9 @@ class LoadedModel:
def model_memory(self): def model_memory(self):
return self.model.model_size() return self.model.model_size()
def model_mmap_residency(self, free=False):
    """Forward to the wrapped model's model_mmap_residency(), passing `free` through."""
    patcher = self.model
    return patcher.model_mmap_residency(free=free)
def model_loaded_memory(self): def model_loaded_memory(self):
return self.model.loaded_size() return self.model.loaded_size()
@ -557,6 +588,7 @@ class LoadedModel:
self.real_model = weakref.ref(real_model) self.real_model = weakref.ref(real_model)
self.model_finalizer = weakref.finalize(real_model, cleanup_models) self.model_finalizer = weakref.finalize(real_model, cleanup_models)
self.model_finalizer.atexit = False
return real_model return real_model
def should_reload_model(self, force_patch_weights=False): def should_reload_model(self, force_patch_weights=False):
@ -628,7 +660,7 @@ def extra_reserved_memory():
def minimum_inference_memory(): def minimum_inference_memory():
return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory() return (1024 * 1024 * 1024) * 0.8 + extra_reserved_memory()
def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_required=0): def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, pins_required=0, ram_required=0):
cleanup_models_gc() cleanup_models_gc()
unloaded_model = [] unloaded_model = []
can_unload = [] can_unload = []
@ -641,13 +673,14 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i)) can_unload.append((-shift_model.model_offloaded_memory(), sys.getrefcount(shift_model.model), shift_model.model_memory(), i))
shift_model.currently_used = False shift_model.currently_used = False
for x in sorted(can_unload): can_unload_sorted = sorted(can_unload)
for x in can_unload_sorted:
i = x[-1] i = x[-1]
memory_to_free = 1e32 memory_to_free = 1e32
ram_to_free = 1e32 pins_to_free = 1e32
if not DISABLE_SMART_MEMORY: if not DISABLE_SMART_MEMORY:
memory_to_free = memory_required - get_free_memory(device) memory_to_free = memory_required - get_free_memory(device)
ram_to_free = ram_required - get_free_ram() pins_to_free = pins_required - get_free_ram()
if current_loaded_models[i].model.is_dynamic() and for_dynamic: if current_loaded_models[i].model.is_dynamic() and for_dynamic:
#don't actually unload dynamic models for the sake of other dynamic models #don't actually unload dynamic models for the sake of other dynamic models
#as that works on-demand. #as that works on-demand.
@ -656,9 +689,18 @@ def free_memory(memory_required, device, keep_loaded=[], for_dynamic=False, ram_
if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free): if memory_to_free > 0 and current_loaded_models[i].model_unload(memory_to_free):
logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"Unloading {current_loaded_models[i].model.model.__class__.__name__}")
unloaded_model.append(i) unloaded_model.append(i)
if ram_to_free > 0: if pins_to_free > 0:
logging.debug(f"PIN Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(pins_to_free)
for x in can_unload_sorted:
i = x[-1]
ram_to_free = ram_required - psutil.virtual_memory().available
if ram_to_free <= 0 and i not in unloaded_model:
continue
resident_memory, _ = current_loaded_models[i].model_mmap_residency(free=True)
if resident_memory > 0:
logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}") logging.debug(f"RAM Unloading {current_loaded_models[i].model.model.__class__.__name__}")
current_loaded_models[i].model.partially_unload_ram(ram_to_free)
for i in sorted(unloaded_model, reverse=True): for i in sorted(unloaded_model, reverse=True):
unloaded_models.append(current_loaded_models.pop(i)) unloaded_models.append(current_loaded_models.pop(i))
@ -724,17 +766,27 @@ def load_models_gpu(models, memory_required=0, force_patch_weights=False, minimu
total_memory_required = {} total_memory_required = {}
total_pins_required = {}
total_ram_required = {} total_ram_required = {}
for loaded_model in models_to_load: for loaded_model in models_to_load:
total_memory_required[loaded_model.device] = total_memory_required.get(loaded_model.device, 0) + loaded_model.model_memory_required(loaded_model.device) device = loaded_model.device
#x2, one to make sure the OS can fit the model for loading in disk cache, and for us to do any pinning we total_memory_required[device] = total_memory_required.get(device, 0) + loaded_model.model_memory_required(device)
#want to do. resident_memory, model_memory = loaded_model.model_mmap_residency()
#FIXME: This should subtract off the to_load current pin consumption. pinned_memory = loaded_model.model.pinned_memory_size()
total_ram_required[loaded_model.device] = total_ram_required.get(loaded_model.device, 0) + loaded_model.model_memory() * 2 #FIXME: This can over-free the pins as it budgets to pin the entire model. We should
#make this JIT to keep as much pinned as possible.
pins_required = model_memory - pinned_memory
ram_required = model_memory - resident_memory
total_pins_required[device] = total_pins_required.get(device, 0) + pins_required
total_ram_required[device] = total_ram_required.get(device, 0) + ram_required
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
free_memory(total_memory_required[device] * 1.1 + extra_mem, device, for_dynamic=free_for_dynamic, ram_required=total_ram_required[device]) free_memory(total_memory_required[device] * 1.1 + extra_mem,
device,
for_dynamic=free_for_dynamic,
pins_required=total_pins_required[device],
ram_required=total_ram_required[device])
for device in total_memory_required: for device in total_memory_required:
if device != torch.device("cpu"): if device != torch.device("cpu"):
@ -1000,6 +1052,12 @@ def intermediate_device():
else: else:
return torch.device("cpu") return torch.device("cpu")
def intermediate_dtype():
    """Return torch.float16 when args.fp16_intermediates is set, otherwise torch.float32."""
    return torch.float16 if args.fp16_intermediates else torch.float32
def vae_device(): def vae_device():
if args.cpu_vae: if args.cpu_vae:
return torch.device("cpu") return torch.device("cpu")
@ -1220,6 +1278,11 @@ def cast_to_gathered(tensors, r, non_blocking=False, stream=None):
dest_view = dest_views.pop(0) dest_view = dest_views.pop(0)
if tensor is None: if tensor is None:
continue continue
if comfy.memory_management.read_tensor_file_slice_into(tensor, dest_view):
continue
storage = tensor._qdata.untyped_storage() if isinstance(tensor, comfy.quant_ops.QuantizedTensor) else tensor.untyped_storage()
if hasattr(storage, "_comfy_tensor_mmap_touched"):
storage._comfy_tensor_mmap_touched = True
dest_view.copy_(tensor, non_blocking=non_blocking) dest_view.copy_(tensor, non_blocking=non_blocking)
@ -1275,7 +1338,7 @@ def discard_cuda_async_error():
b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device()) b = torch.tensor([1], dtype=torch.uint8, device=get_torch_device())
_ = a + b _ = a + b
synchronize() synchronize()
except torch.AcceleratorError: except RuntimeError:
#Dump it! We already know about it from the synchronous return #Dump it! We already know about it from the synchronous return
pass pass
@ -1657,6 +1720,19 @@ def supports_nvfp4_compute(device=None):
return True return True
def supports_mxfp8_compute(device=None):
    """Return True when MXFP8 compute is considered available on `device`.

    Requires an NVIDIA backend, torch >= 2.10, and a CUDA device whose
    reported major compute capability is at least 10.
    """
    if not is_nvidia() or torch_version_numeric < (2, 10):
        return False
    return torch.cuda.get_device_properties(device).major >= 10
def extended_fp16_support(): def extended_fp16_support():
# TODO: check why some models work with fp16 on newer torch versions but not on older # TODO: check why some models work with fp16 on newer torch versions but not on older
if torch_version_numeric < (2, 7): if torch_version_numeric < (2, 7):

View File

@ -297,6 +297,9 @@ class ModelPatcher:
self.size = comfy.model_management.module_size(self.model) self.size = comfy.model_management.module_size(self.model)
return self.size return self.size
def model_mmap_residency(self, free=False):
    """Return comfy.model_management.module_mmap_residency() for self.model, passing `free` through."""
    wrapped = self.model
    return comfy.model_management.module_mmap_residency(wrapped, free=free)
def get_ram_usage(self): def get_ram_usage(self):
return self.model_size() return self.model_size()
@ -1063,6 +1066,10 @@ class ModelPatcher:
return self.model.model_loaded_weight_memory - current_used return self.model.model_loaded_weight_memory - current_used
def pinned_memory_size(self):
    """Return the pinned-memory byte count for this patcher; always 0 here.

    Pinned memory pressure tracking is only implemented for DynamicVram loading.
    """
    return 0
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
pass pass
@ -1653,6 +1660,16 @@ class ModelPatcherDynamic(ModelPatcher):
return freed return freed
def pinned_memory_size(self):
    """Return the combined byte size of all pin buffers found via the dynamic load list."""
    pinned_bytes = 0
    # Each load-list entry is a 6-tuple; only the fifth item is handed to get_pin().
    for _, _, _, _, m, _ in self._load_list(for_dynamic=True):
        pin = comfy.pinned_memory.get_pin(m)
        if pin is None:
            continue
        pinned_bytes += pin.numel() * pin.element_size()
    return pinned_bytes
def partially_unload_ram(self, ram_to_unload): def partially_unload_ram(self, ram_to_unload):
loading = self._load_list(for_dynamic=True, default_device=self.offload_device) loading = self._load_list(for_dynamic=True, default_device=self.offload_device)
for x in loading: for x in loading:

View File

@ -306,10 +306,40 @@ class CastWeightBiasOp:
bias_function = [] bias_function = []
class disable_weight_init: class disable_weight_init:
@staticmethod
def _lazy_load_from_state_dict(module, state_dict, prefix, local_metadata,
missing_keys, unexpected_keys, weight_shape,
bias_shape=None):
assign_to_params_buffers = local_metadata.get("assign_to_params_buffers", False)
prefix_len = len(prefix)
for k, v in state_dict.items():
key = k[prefix_len:]
if key == "weight":
if not assign_to_params_buffers:
v = v.clone()
module.weight = torch.nn.Parameter(v, requires_grad=False)
elif bias_shape is not None and key == "bias" and v is not None:
if not assign_to_params_buffers:
v = v.clone()
module.bias = torch.nn.Parameter(v, requires_grad=False)
else:
unexpected_keys.append(k)
if module.weight is None:
module.weight = torch.nn.Parameter(torch.zeros(weight_shape), requires_grad=False)
missing_keys.append(prefix + "weight")
if bias_shape is not None and module.bias is None and getattr(module, "comfy_need_lazy_init_bias", False):
module.bias = torch.nn.Parameter(torch.zeros(bias_shape), requires_grad=False)
missing_keys.append(prefix + "bias")
class Linear(torch.nn.Linear, CastWeightBiasOp): class Linear(torch.nn.Linear, CastWeightBiasOp):
def __init__(self, in_features, out_features, bias=True, device=None, dtype=None): def __init__(self, in_features, out_features, bias=True, device=None, dtype=None):
if not comfy.model_management.WINDOWS or not comfy.memory_management.aimdo_enabled: # don't trust subclasses that BYO state dict loader to call us.
if (not comfy.model_management.WINDOWS
or not comfy.memory_management.aimdo_enabled
or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
super().__init__(in_features, out_features, bias, device, dtype) super().__init__(in_features, out_features, bias, device, dtype)
return return
@ -330,32 +360,21 @@ class disable_weight_init:
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                          strict, missing_keys, unexpected_keys, error_msgs):
    """Load this layer's weight and bias, using the lazy loader when it applies.

    The stock torch.nn.Linear loader is used unless
    comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    are both set and no subclass has replaced this method. Otherwise loading
    is delegated to disable_weight_init._lazy_load_from_state_dict(), which
    reports problems through `missing_keys` and `unexpected_keys`; `strict`
    and `error_msgs` are not used on that path.
    """
    # don't trust subclasses that BYO state dict loader to call us.
    if (not comfy.model_management.WINDOWS
            or not comfy.memory_management.aimdo_enabled
            or type(self)._load_from_state_dict is not disable_weight_init.Linear._load_from_state_dict):
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
    disable_weight_init._lazy_load_from_state_dict(
        self,
        state_dict,
        prefix,
        local_metadata,
        missing_keys,
        unexpected_keys,
        # torch.nn.Linear stores its weight as (out_features, in_features), so
        # the placeholder for a missing weight must use that layout too.
        weight_shape=(self.out_features, self.in_features),
        bias_shape=(self.out_features,),
    )
def reset_parameters(self): def reset_parameters(self):
@ -547,6 +566,53 @@ class disable_weight_init:
return super().forward(*args, **kwargs) return super().forward(*args, **kwargs)
class Embedding(torch.nn.Embedding, CastWeightBiasOp): class Embedding(torch.nn.Embedding, CastWeightBiasOp):
def __init__(self, num_embeddings, embedding_dim, padding_idx=None, max_norm=None,
             norm_type=2.0, scale_grad_by_freq=False, sparse=False, _weight=None,
             _freeze=False, device=None, dtype=None):
    """Build the embedding, skipping real weight allocation on the lazy path.

    The lazy path is taken only when comfy.model_management.WINDOWS and
    comfy.memory_management.aimdo_enabled are both set and no subclass has
    replaced _load_from_state_dict; otherwise construction is delegated to
    torch.nn.Embedding unchanged.
    """
    # don't trust subclasses that BYO state dict loader to call us.
    use_lazy_weights = (comfy.model_management.WINDOWS
                        and comfy.memory_management.aimdo_enabled
                        and type(self)._load_from_state_dict is disable_weight_init.Embedding._load_from_state_dict)
    if not use_lazy_weights:
        super().__init__(num_embeddings, embedding_dim, padding_idx, max_norm,
                         norm_type, scale_grad_by_freq, sparse, _weight,
                         _freeze, device, dtype)
        return

    torch.nn.Module.__init__(self)
    # NOTE(review): _weight, _freeze and device are not used on this path.
    self.num_embeddings = num_embeddings
    self.embedding_dim = embedding_dim
    self.padding_idx = padding_idx
    self.max_norm = max_norm
    self.norm_type = norm_type
    self.scale_grad_by_freq = scale_grad_by_freq
    self.sparse = sparse

    # Keep shape/dtype visible for module introspection without reserving storage.
    if dtype is not None:
        embedding_dtype = dtype
    else:
        embedding_dtype = torch.get_default_dtype()
    placeholder = torch.empty((num_embeddings, embedding_dim), device="meta", dtype=embedding_dtype)
    self.weight = torch.nn.Parameter(placeholder, requires_grad=False)
    self.bias = None
    self.weight_comfy_model_dtype = dtype
def _load_from_state_dict(self, state_dict, prefix, local_metadata,
                          strict, missing_keys, unexpected_keys, error_msgs):
    """Load the embedding weight, using the lazy loader when it applies.

    The stock torch.nn.Embedding loader is used unless
    comfy.model_management.WINDOWS and comfy.memory_management.aimdo_enabled
    are both set and no subclass has replaced this method. No bias shape is
    passed to the lazy loader, so a "bias" key is reported as unexpected.
    """
    use_lazy_loader = (comfy.model_management.WINDOWS
                       and comfy.memory_management.aimdo_enabled
                       and type(self)._load_from_state_dict is disable_weight_init.Embedding._load_from_state_dict)
    if not use_lazy_loader:
        return super()._load_from_state_dict(state_dict, prefix, local_metadata, strict,
                                             missing_keys, unexpected_keys, error_msgs)
    embedding_shape = (self.num_embeddings, self.embedding_dim)
    disable_weight_init._lazy_load_from_state_dict(
        self, state_dict, prefix, local_metadata,
        missing_keys, unexpected_keys,
        weight_shape=embedding_shape,
    )
def reset_parameters(self): def reset_parameters(self):
self.bias = None self.bias = None
return None return None
@ -710,6 +776,71 @@ from .quant_ops import (
) )
class QuantLinearFunc(torch.autograd.Function):
    """Autograd function for a linear layer with a quantized forward pass.

    The forward pass optionally quantizes the input and calls
    torch.nn.functional.linear; the backward pass runs plain matmuls in
    `compute_dtype`. Inputs of rank > 2 are flattened to 2D for the matmul
    and restored to their leading shape afterwards.
    """

    @staticmethod
    def forward(ctx, input_float, weight, bias, layout_type, input_scale, compute_dtype):
        """Compute linear(input, weight, bias), quantizing the input when `layout_type` is set."""
        input_shape = input_float.shape
        flat_input = input_float.detach().flatten(0, -2)  # collapse leading dims to 2D

        # Quantize input (same as inference path); None means use it as-is.
        if layout_type is None:
            mm_input = flat_input
        else:
            mm_input = QuantizedTensor.from_float(flat_input, layout_type, scale=input_scale)

        # Pass detached tensors to the matmul when they track gradients.
        mm_weight = weight.detach() if weight.requires_grad else weight
        mm_bias = bias.detach() if bias is not None and bias.requires_grad else bias

        output = torch.nn.functional.linear(mm_input, mm_weight, mm_bias)

        # Restore original input shape
        if len(input_shape) > 2:
            output = output.unflatten(0, input_shape[:-1])

        # Backward needs the original float input and the original weight.
        ctx.save_for_backward(input_float, weight)
        ctx.input_shape = input_shape
        ctx.has_bias = bias is not None
        ctx.compute_dtype = compute_dtype
        ctx.weight_requires_grad = weight.requires_grad

        return output

    @staticmethod
    @torch.autograd.function.once_differentiable
    def backward(ctx, grad_output):
        """Return gradients for (input, weight, bias); the remaining three inputs get None."""
        input_float, weight = ctx.saved_tensors
        dtype = ctx.compute_dtype
        flat_grad = grad_output.flatten(0, -2).to(dtype)

        # Dequantize weight to compute dtype for backward matmul
        dense_weight = weight.dequantize() if isinstance(weight, QuantizedTensor) else weight
        dense_weight = dense_weight.to(dtype)

        # grad_input = grad_output @ weight
        grad_input = torch.mm(flat_grad, dense_weight)
        if len(ctx.input_shape) > 2:
            grad_input = grad_input.unflatten(0, ctx.input_shape[:-1])

        # grad_weight (only if weight requires grad, typically frozen for quantized training)
        if ctx.weight_requires_grad:
            flat_input = input_float.flatten(0, -2).to(dtype)
            grad_weight = torch.mm(flat_grad.t(), flat_input)
        else:
            grad_weight = None

        # grad_bias
        grad_bias = flat_grad.sum(dim=0) if ctx.has_bias else None

        return grad_input, grad_weight, grad_bias, None, None, None
def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]): def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_precision_mm=False, disabled=[]):
class MixedPrecisionOps(manual_cast): class MixedPrecisionOps(manual_cast):
_quant_config = quant_config _quant_config = quant_config
@ -801,6 +932,22 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
orig_shape=(self.out_features, self.in_features), orig_shape=(self.out_features, self.in_features),
) )
elif self.quant_format == "mxfp8":
# MXFP8: E8M0 block scales stored as uint8 in safetensors
block_scale = self._load_scale_param(state_dict, prefix, "weight_scale", device, manually_loaded_keys,
dtype=torch.uint8)
if block_scale is None:
raise ValueError(f"Missing MXFP8 block scales for layer {layer_name}")
block_scale = block_scale.view(torch.float8_e8m0fnu)
params = layout_cls.Params(
scale=block_scale,
orig_dtype=MixedPrecisionOps._compute_dtype,
orig_shape=(self.out_features, self.in_features),
)
elif self.quant_format == "nvfp4": elif self.quant_format == "nvfp4":
# NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale) # NVFP4: tensor_scale (weight_scale_2) + block_scale (weight_scale)
tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys) tensor_scale = self._load_scale_param(state_dict, prefix, "weight_scale_2", device, manually_loaded_keys)
@ -888,10 +1035,37 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
#If cast needs to apply lora, it should be done in the compute dtype #If cast needs to apply lora, it should be done in the compute dtype
compute_dtype = input.dtype compute_dtype = input.dtype
if (getattr(self, 'layout_type', None) is not None and _use_quantized = (
getattr(self, 'layout_type', None) is not None and
not isinstance(input, QuantizedTensor) and not self._full_precision_mm and not isinstance(input, QuantizedTensor) and not self._full_precision_mm and
not getattr(self, 'comfy_force_cast_weights', False) and not getattr(self, 'comfy_force_cast_weights', False) and
len(self.weight_function) == 0 and len(self.bias_function) == 0): len(self.weight_function) == 0 and len(self.bias_function) == 0
)
# Training path: quantized forward with compute_dtype backward via autograd function
if (input.requires_grad and _use_quantized):
weight, bias, offload_stream = cast_bias_weight(
self,
input,
offloadable=True,
compute_dtype=compute_dtype,
want_requant=True
)
scale = getattr(self, 'input_scale', None)
if scale is not None:
scale = comfy.model_management.cast_to_device(scale, input.device, None)
output = QuantLinearFunc.apply(
input, weight, bias, self.layout_type, scale, compute_dtype
)
uncast_bias_weight(self, weight, bias, offload_stream)
return output
# Inference path (unchanged)
if _use_quantized:
# Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others) # Reshape 3D tensors to 2D for quantization (needed for NVFP4 and others)
input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input input_reshaped = input.reshape(-1, input_shape[2]) if input.ndim == 3 else input
@ -939,7 +1113,10 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
for key, param in self._parameters.items(): for key, param in self._parameters.items():
if param is None: if param is None:
continue continue
self.register_parameter(key, torch.nn.Parameter(fn(param), requires_grad=False)) p = fn(param)
if p.is_inference():
p = p.clone()
self.register_parameter(key, torch.nn.Parameter(p, requires_grad=False))
for key, buf in self._buffers.items(): for key, buf in self._buffers.items():
if buf is not None: if buf is not None:
self._buffers[key] = fn(buf) self._buffers[key] = fn(buf)
@ -950,12 +1127,15 @@ def mixed_precision_ops(quant_config={}, compute_dtype=torch.bfloat16, full_prec
def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None): def pick_operations(weight_dtype, compute_dtype, load_device=None, disable_fast_fp8=False, fp8_optimizations=False, model_config=None):
fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular fp8_compute = comfy.model_management.supports_fp8_compute(load_device) # TODO: if we support more ops this needs to be more granular
nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device) nvfp4_compute = comfy.model_management.supports_nvfp4_compute(load_device)
mxfp8_compute = comfy.model_management.supports_mxfp8_compute(load_device)
if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config: if model_config and hasattr(model_config, 'quant_config') and model_config.quant_config:
logging.info("Using mixed precision operations") logging.info("Using mixed precision operations")
disabled = set() disabled = set()
if not nvfp4_compute: if not nvfp4_compute:
disabled.add("nvfp4") disabled.add("nvfp4")
if not mxfp8_compute:
disabled.add("mxfp8")
if not fp8_compute: if not fp8_compute:
disabled.add("float8_e4m3fn") disabled.add("float8_e4m3fn")
disabled.add("float8_e5m2") disabled.add("float8_e5m2")

View File

@ -1,6 +1,7 @@
import torch
import comfy.model_management import comfy.model_management
import comfy.memory_management import comfy.memory_management
import comfy_aimdo.host_buffer
import comfy_aimdo.torch
from comfy.cli_args import args from comfy.cli_args import args
@ -12,18 +13,31 @@ def pin_memory(module):
return return
#FIXME: This is a RAM cache trigger event #FIXME: This is a RAM cache trigger event
size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ]) size = comfy.memory_management.vram_aligned_size([ module.weight, module.bias ])
pin = torch.empty((size,), dtype=torch.uint8)
if comfy.model_management.pin_memory(pin): if comfy.model_management.MAX_PINNED_MEMORY <= 0 or (comfy.model_management.TOTAL_PINNED_MEMORY + size) > comfy.model_management.MAX_PINNED_MEMORY:
module._pin = pin
else:
module.pin_failed = True module.pin_failed = True
return False return False
try:
hostbuf = comfy_aimdo.host_buffer.HostBuffer(size)
except RuntimeError:
module.pin_failed = True
return False
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)
module._pin_hostbuf = hostbuf
comfy.model_management.TOTAL_PINNED_MEMORY += size
return True return True
def unpin_memory(module):
    """Drop the pinned host buffer attached to ``module``.

    Returns the size in bytes of the released pin (``numel * element_size``
    of ``module._pin``), or 0 when ``get_pin(module)`` returns None.
    """
    if get_pin(module) is None:
        return 0

    size = module._pin.numel() * module._pin.element_size()

    # Hand the bytes back to the global pinned-memory counter; the counter is
    # clamped so it never goes below zero.
    remaining = comfy.model_management.TOTAL_PINNED_MEMORY - size
    comfy.model_management.TOTAL_PINNED_MEMORY = max(remaining, 0)

    # Remove the tensor view first, then the host buffer object it was built from.
    del module._pin
    del module._pin_hostbuf
    return size

View File

@ -43,6 +43,18 @@ except ImportError as e:
def get_layout_class(name): def get_layout_class(name):
return None return None
# MXFP8 support is optional: it needs comfy_kitchen itself plus a build that
# ships TensorCoreMXFP8Layout. The flag stays False unless that import works.
_CK_MXFP8_AVAILABLE = False
if _CK_AVAILABLE:
    try:
        from comfy_kitchen.tensor import TensorCoreMXFP8Layout as _CKMxfp8Layout
    except ImportError:
        logging.warning("comfy_kitchen does not support MXFP8, please update comfy_kitchen.")
    else:
        _CK_MXFP8_AVAILABLE = True

if not _CK_MXFP8_AVAILABLE:
    # Empty stand-in so classes deriving from _CKMxfp8Layout can still be defined.
    class _CKMxfp8Layout:
        pass
import comfy.float import comfy.float
# ============================================================================== # ==============================================================================
@ -84,6 +96,31 @@ class _TensorCoreFP8LayoutBase(_CKFp8Layout):
return qdata, params return qdata, params
class TensorCoreMXFP8Layout(_CKMxfp8Layout):
    """MXFP8 layout built on the comfy_kitchen base class, with optional stochastic rounding."""

    @classmethod
    def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
        """Quantize a 2D tensor to MXFP8.

        Args:
            tensor: Tensor to quantize; must be 2-dimensional.
            scale: Accepted for signature parity; not read by this method.
            stochastic_rounding: When greater than 0, the stochastic-rounding
                quantizer is used and this value is passed to it as ``seed``.
                Otherwise ``ck.quantize_mxfp8`` is used.
            inplace_ops: Accepted for signature parity; not read by this method.

        Returns:
            ``(qdata, params)`` where ``params`` is a ``cls.Params`` holding the
            block scale and the input's original dtype and shape.

        Raises:
            ValueError: If ``tensor`` is not 2-dimensional.
        """
        if tensor.dim() != 2:
            raise ValueError(f"MXFP8 requires 2D tensor, got {tensor.dim()}D")

        source_dtype = tensor.dtype
        source_shape = tuple(tensor.shape)

        # Only request padding when the layout's padded shape differs from the input shape.
        pad = cls.get_padded_shape(source_shape) != source_shape

        if stochastic_rounding > 0:
            qdata, block_scale = comfy.float.stochastic_round_quantize_mxfp8_by_block(
                tensor, pad_32x=pad, seed=stochastic_rounding
            )
        else:
            qdata, block_scale = ck.quantize_mxfp8(tensor, pad_32x=pad)

        return qdata, cls.Params(
            scale=block_scale,
            orig_dtype=source_dtype,
            orig_shape=source_shape,
        )
class TensorCoreNVFP4Layout(_CKNvfp4Layout): class TensorCoreNVFP4Layout(_CKNvfp4Layout):
@classmethod @classmethod
def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False): def quantize(cls, tensor, scale=None, stochastic_rounding=0, inplace_ops=False):
@ -137,6 +174,8 @@ register_layout_class("TensorCoreFP8Layout", TensorCoreFP8Layout)
register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout) register_layout_class("TensorCoreFP8E4M3Layout", TensorCoreFP8E4M3Layout)
register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout) register_layout_class("TensorCoreFP8E5M2Layout", TensorCoreFP8E5M2Layout)
register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout) register_layout_class("TensorCoreNVFP4Layout", TensorCoreNVFP4Layout)
if _CK_MXFP8_AVAILABLE:
register_layout_class("TensorCoreMXFP8Layout", TensorCoreMXFP8Layout)
QUANT_ALGOS = { QUANT_ALGOS = {
"float8_e4m3fn": { "float8_e4m3fn": {
@ -157,6 +196,14 @@ QUANT_ALGOS = {
}, },
} }
if _CK_MXFP8_AVAILABLE:
QUANT_ALGOS["mxfp8"] = {
"storage_t": torch.float8_e4m3fn,
"parameters": {"weight_scale", "input_scale"},
"comfy_tensor_layout": "TensorCoreMXFP8Layout",
"group_size": 32,
}
# ============================================================================== # ==============================================================================
# Re-exports for backward compatibility # Re-exports for backward compatibility

View File

@ -8,12 +8,12 @@ import comfy.nested_tensor
def prepare_noise_inner(latent_image, generator, noise_inds=None): def prepare_noise_inner(latent_image, generator, noise_inds=None):
if noise_inds is None: if noise_inds is None:
return torch.randn(latent_image.size(), dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") return torch.randn(latent_image.size(), dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
unique_inds, inverse = np.unique(noise_inds, return_inverse=True) unique_inds, inverse = np.unique(noise_inds, return_inverse=True)
noises = [] noises = []
for i in range(unique_inds[-1]+1): for i in range(unique_inds[-1]+1):
noise = torch.randn([1] + list(latent_image.size())[1:], dtype=latent_image.dtype, layout=latent_image.layout, generator=generator, device="cpu") noise = torch.randn([1] + list(latent_image.size())[1:], dtype=torch.float32, layout=latent_image.layout, generator=generator, device="cpu").to(dtype=latent_image.dtype)
if i in unique_inds: if i in unique_inds:
noises.append(noise) noises.append(noise)
noises = [noises[i] for i in inverse] noises = [noises[i] for i in inverse]
@ -64,10 +64,10 @@ def sample(model, noise, steps, cfg, sampler_name, scheduler, positive, negative
sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options) sampler = comfy.samplers.KSampler(model, steps=steps, device=model.load_device, sampler=sampler_name, scheduler=scheduler, denoise=denoise, model_options=model.model_options)
samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed) samples = sampler.sample(noise, positive, negative, cfg=cfg, latent_image=latent_image, start_step=start_step, last_step=last_step, force_full_denoise=force_full_denoise, denoise_mask=noise_mask, sigmas=sigmas, callback=callback, disable_pbar=disable_pbar, seed=seed)
samples = samples.to(comfy.model_management.intermediate_device()) samples = samples.to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return samples return samples
def sample_custom(model, noise, cfg, sampler, sigmas, positive, negative, latent_image, noise_mask=None, callback=None, disable_pbar=False, seed=None):
    """Sample with a caller-supplied sampler object and sigma schedule.

    Forwards to ``comfy.samplers.sample`` using the model's ``load_device``
    and ``model_options``. The result is moved to the intermediate device and
    cast to the intermediate dtype reported by ``comfy.model_management``
    before it is returned.
    """
    samples = comfy.samplers.sample(
        model, noise, positive, negative, cfg, model.load_device, sampler, sigmas,
        model_options=model.model_options,
        latent_image=latent_image,
        denoise_mask=noise_mask,
        callback=callback,
        disable_pbar=disable_pbar,
        seed=seed,
    )
    return samples.to(
        device=comfy.model_management.intermediate_device(),
        dtype=comfy.model_management.intermediate_dtype(),
    )

View File

@ -985,8 +985,8 @@ class CFGGuider:
self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options) self.inner_model, self.conds, self.loaded_models = comfy.sampler_helpers.prepare_sampling(self.model_patcher, noise.shape, self.conds, self.model_options)
device = self.model_patcher.load_device device = self.model_patcher.load_device
noise = noise.to(device) noise = noise.to(device=device, dtype=torch.float32)
latent_image = latent_image.to(device) latent_image = latent_image.to(device=device, dtype=torch.float32)
sigmas = sigmas.to(device) sigmas = sigmas.to(device)
cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype()) cast_to_load_options(self.model_options, device=device, dtype=self.model_patcher.model_dtype())
@ -1028,6 +1028,7 @@ class CFGGuider:
denoise_mask, _ = comfy.utils.pack_latents(denoise_masks) denoise_mask, _ = comfy.utils.pack_latents(denoise_masks)
else: else:
denoise_mask = denoise_masks[0] denoise_mask = denoise_masks[0]
denoise_mask = denoise_mask.float()
self.conds = {} self.conds = {}
for k in self.original_conds: for k in self.original_conds:

View File

@ -455,7 +455,7 @@ class VAE:
self.output_channels = 3 self.output_channels = 3
self.pad_channel_value = None self.pad_channel_value = None
self.process_input = lambda image: image * 2.0 - 1.0 self.process_input = lambda image: image * 2.0 - 1.0
self.process_output = lambda image: torch.clamp((image + 1.0) / 2.0, min=0.0, max=1.0) self.process_output = lambda image: image.add_(1.0).div_(2.0).clamp_(0.0, 1.0)
self.working_dtypes = [torch.bfloat16, torch.float32] self.working_dtypes = [torch.bfloat16, torch.float32]
self.disable_offload = False self.disable_offload = False
self.not_video = False self.not_video = False
@ -871,13 +871,16 @@ class VAE:
pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value) pixels = torch.nn.functional.pad(pixels, (0, self.output_channels - pixels.shape[-1]), mode=mode, value=value)
return pixels return pixels
def vae_output_dtype(self):
    """Return the dtype reported by ``model_management.intermediate_dtype()``."""
    return model_management.intermediate_dtype()
def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16): def decode_tiled_(self, samples, tile_x=64, tile_y=64, overlap = 16):
steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap) steps = samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x, tile_y, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x // 2, tile_y * 2, overlap)
steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += samples.shape[0] * comfy.utils.get_tiled_scale_steps(samples.shape[3], samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).float() decode_fn = lambda a: self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
output = self.process_output( output = self.process_output(
(comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + (comfy.utils.tiled_scale(samples, decode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) + comfy.utils.tiled_scale(samples, decode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = self.upscale_ratio, output_device=self.output_device, pbar = pbar) +
@ -887,16 +890,16 @@ class VAE:
def decode_tiled_1d(self, samples, tile_x=256, overlap=32):
    """Decode ``samples`` in tiles of ``tile_x`` along a single tiled axis.

    3-D input is handed to the decoder unchanged. Input of any other rank is
    reshaped so axes 1 and 2 are merged before tiling, and each tile is
    reshaped back to ``(-1, og_shape[1], og_shape[2], tile_len)`` just before
    it reaches the decoder. Decoder output is cast to ``vae_output_dtype()``
    and the stitched result is passed through ``process_output``.
    """
    og_shape = samples.shape
    needs_unflatten = samples.ndim != 3
    if needs_unflatten:
        samples = samples.reshape((og_shape[0], og_shape[1] * og_shape[2], -1))

    def decode_fn(a):
        if needs_unflatten:
            # Undo the axis merge for this tile only.
            a = a.reshape((-1, og_shape[1], og_shape[2], a.shape[-1]))
        decoded = self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device))
        return decoded.to(dtype=self.vae_output_dtype())

    tiled = comfy.utils.tiled_scale_multidim(
        samples,
        decode_fn,
        tile=(tile_x,),
        overlap=overlap,
        upscale_amount=self.upscale_ratio,
        out_channels=self.output_channels,
        output_device=self.output_device,
    )
    return self.process_output(tiled)
def decode_tiled_3d(self, samples, tile_t=999, tile_x=32, tile_y=32, overlap=(1, 8, 8)):
    """Decode ``samples`` in ``(tile_t, tile_x, tile_y)`` tiles.

    Each tile is cast to ``vae_dtype``, moved to ``self.device``, decoded, and
    cast to ``vae_output_dtype()``. The stitched result from
    ``comfy.utils.tiled_scale_multidim`` is passed through ``process_output``.
    """
    def decode_fn(a):
        decoded = self.first_stage_model.decode(a.to(self.vae_dtype).to(self.device))
        return decoded.to(dtype=self.vae_output_dtype())

    tiled = comfy.utils.tiled_scale_multidim(
        samples,
        decode_fn,
        tile=(tile_t, tile_x, tile_y),
        overlap=overlap,
        upscale_amount=self.upscale_ratio,
        out_channels=self.output_channels,
        index_formulas=self.upscale_index_formula,
        output_device=self.output_device,
    )
    return self.process_output(tiled)
def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64): def encode_tiled_(self, pixel_samples, tile_x=512, tile_y=512, overlap = 64):
@ -905,7 +908,7 @@ class VAE:
steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap) steps += pixel_samples.shape[0] * comfy.utils.get_tiled_scale_steps(pixel_samples.shape[3], pixel_samples.shape[2], tile_x * 2, tile_y // 2, overlap)
pbar = comfy.utils.ProgressBar(steps) pbar = comfy.utils.ProgressBar(steps)
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples = comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x, tile_y, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x * 2, tile_y // 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar) samples += comfy.utils.tiled_scale(pixel_samples, encode_fn, tile_x // 2, tile_y * 2, overlap, upscale_amount = (1/self.downscale_ratio), out_channels=self.latent_channels, output_device=self.output_device, pbar=pbar)
@ -914,7 +917,7 @@ class VAE:
def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048): def encode_tiled_1d(self, samples, tile_x=256 * 2048, overlap=64 * 2048):
if self.latent_dim == 1: if self.latent_dim == 1:
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).to(dtype=self.vae_output_dtype())
out_channels = self.latent_channels out_channels = self.latent_channels
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
else: else:
@ -923,7 +926,7 @@ class VAE:
tile_x = tile_x // extra_channel_size tile_x = tile_x // extra_channel_size
overlap = overlap // extra_channel_size overlap = overlap // extra_channel_size
upscale_amount = 1 / self.downscale_ratio upscale_amount = 1 / self.downscale_ratio
encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).float() encode_fn = lambda a: self.first_stage_model.encode((self.process_input(a)).to(self.vae_dtype).to(self.device)).reshape(1, out_channels, -1).to(dtype=self.vae_output_dtype())
out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device) out = comfy.utils.tiled_scale_multidim(samples, encode_fn, tile=(tile_x,), overlap=overlap, upscale_amount=upscale_amount, out_channels=out_channels, output_device=self.output_device)
if self.latent_dim == 1: if self.latent_dim == 1:
@ -932,7 +935,7 @@ class VAE:
return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1) return out.reshape(samples.shape[0], self.latent_channels, extra_channel_size, -1)
def encode_tiled_3d(self, samples, tile_t=9999, tile_x=512, tile_y=512, overlap=(1, 64, 64)):
    """Encode ``samples`` in ``(tile_t, tile_x, tile_y)`` tiles.

    Each tile goes through ``process_input``, is cast to ``vae_dtype``, moved
    to ``self.device``, encoded, and cast to ``vae_output_dtype()``. Returns
    the stitched result of ``comfy.utils.tiled_scale_multidim`` called with
    ``downscale=True``.
    """
    def encode_fn(a):
        prepared = self.process_input(a).to(self.vae_dtype).to(self.device)
        return self.first_stage_model.encode(prepared).to(dtype=self.vae_output_dtype())

    return comfy.utils.tiled_scale_multidim(
        samples,
        encode_fn,
        tile=(tile_t, tile_x, tile_y),
        overlap=overlap,
        upscale_amount=self.downscale_ratio,
        out_channels=self.latent_channels,
        downscale=True,
        index_formulas=self.downscale_index_formula,
        output_device=self.output_device,
    )
def decode(self, samples_in, vae_options={}): def decode(self, samples_in, vae_options={}):
@ -948,12 +951,23 @@ class VAE:
batch_number = int(free_memory / memory_used) batch_number = int(free_memory / memory_used)
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
# Pre-allocate output for VAEs that support direct buffer writes
preallocated = False
if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
pixel_samples = torch.empty(self.first_stage_model.decode_output_shape(samples_in.shape), device=self.output_device, dtype=self.vae_output_dtype())
preallocated = True
for x in range(0, samples_in.shape[0], batch_number): for x in range(0, samples_in.shape[0], batch_number):
samples = samples_in[x:x+batch_number].to(self.vae_dtype).to(self.device) samples = samples_in[x:x + batch_number].to(device=self.device, dtype=self.vae_dtype)
out = self.process_output(self.first_stage_model.decode(samples, **vae_options).to(self.output_device).float()) if preallocated:
if pixel_samples is None: self.first_stage_model.decode(samples, output_buffer=pixel_samples[x:x+batch_number], **vae_options)
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device) else:
pixel_samples[x:x+batch_number] = out out = self.first_stage_model.decode(samples, **vae_options).to(device=self.output_device, dtype=self.vae_output_dtype(), copy=True)
if pixel_samples is None:
pixel_samples = torch.empty((samples_in.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
pixel_samples[x:x+batch_number].copy_(out)
del out
self.process_output(pixel_samples[x:x+batch_number])
except Exception as e: except Exception as e:
model_management.raise_non_oom(e) model_management.raise_non_oom(e)
logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.") logging.warning("Warning: Ran out of memory when regular VAE decoding, retrying with tiled VAE decoding.")
@ -964,6 +978,7 @@ class VAE:
do_tile = True do_tile = True
if do_tile: if do_tile:
comfy.model_management.soft_empty_cache()
dims = samples_in.ndim - 2 dims = samples_in.ndim - 2
if dims == 1 or self.extra_1d_channel is not None: if dims == 1 or self.extra_1d_channel is not None:
pixel_samples = self.decode_tiled_1d(samples_in) pixel_samples = self.decode_tiled_1d(samples_in)
@ -1024,10 +1039,15 @@ class VAE:
batch_number = max(1, batch_number) batch_number = max(1, batch_number)
samples = None samples = None
for x in range(0, pixel_samples.shape[0], batch_number): for x in range(0, pixel_samples.shape[0], batch_number):
pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype).to(self.device) pixels_in = self.process_input(pixel_samples[x:x + batch_number]).to(self.vae_dtype)
out = self.first_stage_model.encode(pixels_in).to(self.output_device).float() if getattr(self.first_stage_model, 'comfy_has_chunked_io', False):
out = self.first_stage_model.encode(pixels_in, device=self.device)
else:
pixels_in = pixels_in.to(self.device)
out = self.first_stage_model.encode(pixels_in)
out = out.to(self.output_device).to(dtype=self.vae_output_dtype())
if samples is None: if samples is None:
samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device) samples = torch.empty((pixel_samples.shape[0],) + tuple(out.shape[1:]), device=self.output_device, dtype=self.vae_output_dtype())
samples[x:x + batch_number] = out samples[x:x + batch_number] = out
except Exception as e: except Exception as e:
@ -1040,6 +1060,7 @@ class VAE:
do_tile = True do_tile = True
if do_tile: if do_tile:
comfy.model_management.soft_empty_cache()
if self.latent_dim == 3: if self.latent_dim == 3:
tile = 256 tile = 256
overlap = tile // 4 overlap = tile // 4

View File

@ -46,7 +46,7 @@ class ClipTokenWeightEncoder:
out, pooled = o[:2] out, pooled = o[:2]
if pooled is not None: if pooled is not None:
first_pooled = pooled[0:1].to(model_management.intermediate_device()) first_pooled = pooled[0:1].to(device=model_management.intermediate_device())
else: else:
first_pooled = pooled first_pooled = pooled
@ -63,16 +63,16 @@ class ClipTokenWeightEncoder:
output.append(z) output.append(z)
if (len(output) == 0): if (len(output) == 0):
r = (out[-1:].to(model_management.intermediate_device()), first_pooled) r = (out[-1:].to(device=model_management.intermediate_device()), first_pooled)
else: else:
r = (torch.cat(output, dim=-2).to(model_management.intermediate_device()), first_pooled) r = (torch.cat(output, dim=-2).to(device=model_management.intermediate_device()), first_pooled)
if len(o) > 2: if len(o) > 2:
extra = {} extra = {}
for k in o[2]: for k in o[2]:
v = o[2][k] v = o[2][k]
if k == "attention_mask": if k == "attention_mask":
v = v[:sections].flatten().unsqueeze(dim=0).to(model_management.intermediate_device()) v = v[:sections].flatten().unsqueeze(dim=0).to(device=model_management.intermediate_device())
extra[k] = v extra[k] = v
r = r + (extra,) r = r + (extra,)

View File

@ -1028,12 +1028,19 @@ class Qwen25_7BVLI(BaseLlama, BaseGenerate, torch.nn.Module):
grid = e.get("extra", None) grid = e.get("extra", None)
start = e.get("index") start = e.get("index")
if position_ids is None: if position_ids is None:
position_ids = torch.zeros((3, embeds.shape[1]), device=embeds.device) position_ids = torch.ones((3, embeds.shape[1]), device=embeds.device, dtype=torch.long)
position_ids[:, :start] = torch.arange(0, start, device=embeds.device) position_ids[:, :start] = torch.arange(0, start, device=embeds.device)
end = e.get("size") + start end = e.get("size") + start
len_max = int(grid.max()) // 2 len_max = int(grid.max()) // 2
start_next = len_max + start start_next = len_max + start
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device) if attention_mask is not None:
# Assign compact sequential positions to attended tokens only,
# skipping over padding so post-padding tokens aren't inflated.
after_mask = attention_mask[0, end:]
text_positions = after_mask.cumsum(0) - 1 + start_next + offset
position_ids[:, end:] = torch.where(after_mask.bool(), text_positions, position_ids[0, end:])
else:
position_ids[:, end:] = torch.arange(start_next + offset, start_next + (embeds.shape[1] - end) + offset, device=embeds.device)
position_ids[0, start:end] = start + offset position_ids[0, start:end] = start + offset
max_d = int(grid[0][1]) // 2 max_d = int(grid[0][1]) // 2
position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start] position_ids[1, start:end] = torch.arange(start + offset, start + max_d + offset, device=embeds.device).unsqueeze(1).repeat(1, math.ceil((end - start) / max_d)).flatten(0)[:end - start]

View File

@ -64,7 +64,13 @@ class LongCatImageBaseTokenizer(Qwen25_7BVLITokenizer):
return [output] return [output]
IMAGE_PAD_TOKEN_ID = 151655
class LongCatImageTokenizer(sd1_clip.SD1Tokenizer): class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
T2I_PREFIX = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
EDIT_PREFIX = "<|im_start|>system\nAs an image editing expert, first analyze the content and attributes of the input image(s). Then, based on the user's editing instructions, clearly and precisely determine how to modify the given image(s), ensuring that only the specified parts are altered and all other aspects remain consistent with the original(s).<|im_end|>\n<|im_start|>user\n<|vision_start|><|image_pad|><|vision_end|>"
SUFFIX = "<|im_end|>\n<|im_start|>assistant\n"
def __init__(self, embedding_directory=None, tokenizer_data={}): def __init__(self, embedding_directory=None, tokenizer_data={}):
super().__init__( super().__init__(
embedding_directory=embedding_directory, embedding_directory=embedding_directory,
@ -72,10 +78,8 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
name="qwen25_7b", name="qwen25_7b",
tokenizer=LongCatImageBaseTokenizer, tokenizer=LongCatImageBaseTokenizer,
) )
self.longcat_template_prefix = "<|im_start|>system\nAs an image captioning expert, generate a descriptive text prompt based on an image content, suitable for input to a text-to-image model.<|im_end|>\n<|im_start|>user\n"
self.longcat_template_suffix = "<|im_end|>\n<|im_start|>assistant\n"
def tokenize_with_weights(self, text, return_word_ids=False, **kwargs): def tokenize_with_weights(self, text, return_word_ids=False, images=None, **kwargs):
skip_template = False skip_template = False
if text.startswith("<|im_start|>"): if text.startswith("<|im_start|>"):
skip_template = True skip_template = True
@ -90,11 +94,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
text, return_word_ids=return_word_ids, disable_weights=True, **kwargs text, return_word_ids=return_word_ids, disable_weights=True, **kwargs
) )
else: else:
has_images = images is not None and len(images) > 0
template_prefix = self.EDIT_PREFIX if has_images else self.T2I_PREFIX
prefix_ids = base_tok.tokenizer( prefix_ids = base_tok.tokenizer(
self.longcat_template_prefix, add_special_tokens=False template_prefix, add_special_tokens=False
)["input_ids"] )["input_ids"]
suffix_ids = base_tok.tokenizer( suffix_ids = base_tok.tokenizer(
self.longcat_template_suffix, add_special_tokens=False self.SUFFIX, add_special_tokens=False
)["input_ids"] )["input_ids"]
prompt_tokens = base_tok.tokenize_with_weights( prompt_tokens = base_tok.tokenize_with_weights(
@ -106,6 +113,14 @@ class LongCatImageTokenizer(sd1_clip.SD1Tokenizer):
suffix_pairs = [(t, 1.0) for t in suffix_ids] suffix_pairs = [(t, 1.0) for t in suffix_ids]
combined = prefix_pairs + prompt_pairs + suffix_pairs combined = prefix_pairs + prompt_pairs + suffix_pairs
if has_images:
embed_count = 0
for i in range(len(combined)):
if combined[i][0] == IMAGE_PAD_TOKEN_ID and embed_count < len(images):
combined[i] = ({"type": "image", "data": images[embed_count], "original_type": "image"}, combined[i][1])
embed_count += 1
tokens = {"qwen25_7b": [combined]} tokens = {"qwen25_7b": [combined]}
return tokens return tokens

View File

@ -425,4 +425,7 @@ class Qwen2VLVisionTransformer(nn.Module):
hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention) hidden_states = block(hidden_states, position_embeddings, cu_seqlens_now, optimized_attention=optimized_attention)
hidden_states = self.merger(hidden_states) hidden_states = self.merger(hidden_states)
# Potentially important for spatially precise edits. This is present in the HF implementation.
reverse_indices = torch.argsort(window_index)
hidden_states = hidden_states[reverse_indices, :]
return hidden_states return hidden_states

View File

@ -20,6 +20,8 @@
import torch import torch
import math import math
import struct import struct
import ctypes
import os
import comfy.memory_management import comfy.memory_management
import safetensors.torch import safetensors.torch
import numpy as np import numpy as np
@ -32,7 +34,7 @@ from einops import rearrange
from comfy.cli_args import args from comfy.cli_args import args
import json import json
import time import time
import mmap import threading
import warnings import warnings
MMAP_TORCH_FILES = args.mmap_torch_files MMAP_TORCH_FILES = args.mmap_torch_files
@ -81,14 +83,17 @@ _TYPES = {
} }
def load_safetensors(ckpt): def load_safetensors(ckpt):
f = open(ckpt, "rb") import comfy_aimdo.model_mmap
mapping = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
mv = memoryview(mapping)
header_size = struct.unpack("<Q", mapping[:8])[0] f = open(ckpt, "rb", buffering=0)
header = json.loads(mapping[8:8+header_size].decode("utf-8")) model_mmap = comfy_aimdo.model_mmap.ModelMMAP(ckpt)
file_size = os.path.getsize(ckpt)
mv = memoryview((ctypes.c_uint8 * file_size).from_address(model_mmap.get()))
mv = mv[8 + header_size:] header_size = struct.unpack("<Q", mv[:8])[0]
header = json.loads(mv[8:8 + header_size].tobytes().decode("utf-8"))
mv = mv[(data_base_offset := 8 + header_size):]
sd = {} sd = {}
for name, info in header.items(): for name, info in header.items():
@ -102,7 +107,14 @@ def load_safetensors(ckpt):
with warnings.catch_warnings(): with warnings.catch_warnings():
#We are working with read-only RAM by design #We are working with read-only RAM by design
warnings.filterwarnings("ignore", message="The given buffer is not writable") warnings.filterwarnings("ignore", message="The given buffer is not writable")
sd[name] = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"]) tensor = torch.frombuffer(mv[start:end], dtype=_TYPES[info["dtype"]]).view(info["shape"])
storage = tensor.untyped_storage()
setattr(storage,
"_comfy_tensor_file_slice",
comfy.memory_management.TensorFileSlice(f, threading.get_ident(), data_base_offset + start, end - start))
setattr(storage, "_comfy_tensor_mmap_refs", (model_mmap, mv))
setattr(storage, "_comfy_tensor_mmap_touched", False)
sd[name] = tensor
return sd, header.get("__metadata__", {}), return sd, header.get("__metadata__", {}),
@ -885,6 +897,10 @@ def set_attr(obj, attr, value):
return prev return prev
def set_attr_param(obj, attr, value): def set_attr_param(obj, attr, value):
# Clone inference tensors (created under torch.inference_mode) since
# their version counter is frozen and nn.Parameter() cannot wrap them.
if (not torch.is_inference_mode_enabled()) and value.is_inference():
value = value.clone()
return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False)) return set_attr(obj, attr, torch.nn.Parameter(value, requires_grad=False))
def set_attr_buffer(obj, attr, value): def set_attr_buffer(obj, attr, value):
@ -1119,8 +1135,8 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
pbar.update(1) pbar.update(1)
continue continue
out = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) out = output[b:b+1].zero_()
out_div = torch.zeros([s.shape[0], out_channels] + mult_list_upscale(s.shape[2:]), device=output_device) out_div = torch.zeros([s.shape[0], 1] + mult_list_upscale(s.shape[2:]), device=output_device)
positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)] positions = [range(0, s.shape[d+2] - overlap[d], tile[d] - overlap[d]) if s.shape[d+2] > tile[d] else [0] for d in range(dims)]
@ -1135,7 +1151,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
upscaled.append(round(get_pos(d, pos))) upscaled.append(round(get_pos(d, pos)))
ps = function(s_in).to(output_device) ps = function(s_in).to(output_device)
mask = torch.ones_like(ps) mask = torch.ones([1, 1] + list(ps.shape[2:]), device=output_device)
for d in range(2, dims + 2): for d in range(2, dims + 2):
feather = round(get_scale(d - 2, overlap[d - 2])) feather = round(get_scale(d - 2, overlap[d - 2]))
@ -1158,7 +1174,7 @@ def tiled_scale_multidim(samples, function, tile=(64, 64), overlap=8, upscale_am
if pbar is not None: if pbar is not None:
pbar.update(1) pbar.update(1)
output[b:b+1] = out/out_div out.div_(out_div)
return output return output
def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None): def tiled_scale(samples, function, tile_x=64, tile_y=64, overlap = 8, upscale_amount = 4, out_channels = 3, output_device="cpu", pbar = None):

View File

@ -25,6 +25,7 @@ class ComfyAPI_latest(ComfyAPIBase):
super().__init__() super().__init__()
self.node_replacement = self.NodeReplacement() self.node_replacement = self.NodeReplacement()
self.execution = self.Execution() self.execution = self.Execution()
self.caching = self.Caching()
class NodeReplacement(ProxiedSingleton): class NodeReplacement(ProxiedSingleton):
async def register(self, node_replace: io.NodeReplace) -> None: async def register(self, node_replace: io.NodeReplace) -> None:
@ -84,6 +85,36 @@ class ComfyAPI_latest(ComfyAPIBase):
image=to_display, image=to_display,
) )
class Caching(ProxiedSingleton):
    """External cache provider API.

    Allows cached node outputs to be shared across ComfyUI instances by
    registering provider objects.

    Example::

        from comfy_api.latest import Caching

        class MyCacheProvider(Caching.CacheProvider):
            async def on_lookup(self, context):
                ...  # check external storage

            async def on_store(self, context, value):
                ...  # store to external storage

        Caching.register_provider(MyCacheProvider())
    """

    # Imported in the class body so the types are reachable as attributes,
    # e.g. ``Caching.CacheProvider``.
    from ._caching import CacheProvider, CacheContext, CacheValue

    async def register_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
        """Register an external cache provider. Providers are called in registration order."""
        # Imported lazily, at call time rather than at module import.
        from comfy_execution import cache_provider as _registry

        _registry.register_cache_provider(provider)

    async def unregister_provider(self, provider: "ComfyAPI_latest.Caching.CacheProvider") -> None:
        """Unregister a previously registered cache provider."""
        from comfy_execution import cache_provider as _registry

        _registry.unregister_cache_provider(provider)
class ComfyExtension(ABC): class ComfyExtension(ABC):
async def on_load(self) -> None: async def on_load(self) -> None:
""" """
@ -116,6 +147,9 @@ class Types:
VOXEL = VOXEL VOXEL = VOXEL
File3D = File3D File3D = File3D
Caching = ComfyAPI_latest.Caching
ComfyAPI = ComfyAPI_latest ComfyAPI = ComfyAPI_latest
# Create a synchronous version of the API # Create a synchronous version of the API
@ -135,6 +169,7 @@ __all__ = [
"Input", "Input",
"InputImpl", "InputImpl",
"Types", "Types",
"Caching",
"ComfyExtension", "ComfyExtension",
"io", "io",
"IO", "IO",

View File

@ -0,0 +1,42 @@
from abc import ABC, abstractmethod
from typing import Optional
from dataclasses import dataclass
@dataclass
class CacheContext:
    """Context handed to cache provider hooks to identify a cache entry."""

    node_id: str
    class_type: str
    # SHA256 hex digest
    cache_key_hash: str
@dataclass
class CacheValue:
    """Payload exchanged with an external cache provider.

    ``ui`` is optional: it defaults to ``None`` and is therefore annotated
    as ``Optional[dict]`` rather than a bare ``dict``.
    """

    outputs: list
    ui: Optional[dict] = None
class CacheProvider(ABC):
    """Abstract base class for external cache providers.

    Exceptions from provider methods are caught by the caller and never break execution.
    """

    @abstractmethod
    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """Called on local cache miss. Return CacheValue if found, None otherwise."""

    @abstractmethod
    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """Called after local store. Dispatched via asyncio.create_task."""

    def should_cache(self, context: CacheContext, value: Optional[CacheValue] = None) -> bool:
        """Return False to skip external caching for this node. Default: True."""
        return True

    def on_prompt_start(self, prompt_id: str) -> None:
        """Optional hook; does nothing unless overridden."""
        return None

    def on_prompt_end(self, prompt_id: str) -> None:
        """Optional hook; does nothing unless overridden."""
        return None

View File

@ -272,7 +272,7 @@ class VideoFromFile(VideoInput):
has_first_frame = False has_first_frame = False
for frame in frames: for frame in frames:
offset_seconds = start_time - frame.pts * audio_stream.time_base offset_seconds = start_time - frame.pts * audio_stream.time_base
to_skip = int(offset_seconds * audio_stream.sample_rate) to_skip = max(0, int(offset_seconds * audio_stream.sample_rate))
if to_skip < frame.samples: if to_skip < frame.samples:
has_first_frame = True has_first_frame = True
break break
@ -280,7 +280,7 @@ class VideoFromFile(VideoInput):
audio_frames.append(frame.to_ndarray()[..., to_skip:]) audio_frames.append(frame.to_ndarray()[..., to_skip:])
for frame in frames: for frame in frames:
if frame.time > start_time + self.__duration: if self.__duration and frame.time > start_time + self.__duration:
break break
audio_frames.append(frame.to_ndarray()) # shape: (channels, samples) audio_frames.append(frame.to_ndarray()) # shape: (channels, samples)
if len(audio_frames) > 0: if len(audio_frames) > 0:

View File

@ -297,7 +297,7 @@ class Float(ComfyTypeIO):
'''Float input.''' '''Float input.'''
def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None, def __init__(self, id: str, display_name: str=None, optional=False, tooltip: str=None, lazy: bool=None,
default: float=None, min: float=None, max: float=None, step: float=None, round: float=None, default: float=None, min: float=None, max: float=None, step: float=None, round: float=None,
display_mode: NumberDisplay=None, gradient_stops: list[list[float]]=None, display_mode: NumberDisplay=None, gradient_stops: list[dict]=None,
socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None): socketless: bool=None, force_input: bool=None, extra_dict=None, raw_link: bool=None, advanced: bool=None):
super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced) super().__init__(id, display_name, optional, tooltip, lazy, default, socketless, None, force_input, extra_dict, raw_link, advanced)
self.min = min self.min = min

View File

@ -67,6 +67,7 @@ class GeminiPart(BaseModel):
inlineData: GeminiInlineData | None = Field(None) inlineData: GeminiInlineData | None = Field(None)
fileData: GeminiFileData | None = Field(None) fileData: GeminiFileData | None = Field(None)
text: str | None = Field(None) text: str | None = Field(None)
thought: bool | None = Field(None)
class GeminiTextPart(BaseModel): class GeminiTextPart(BaseModel):

View File

@ -0,0 +1,43 @@
from pydantic import BaseModel, Field
class QuiverImageObject(BaseModel):
    """An image referenced by URL."""

    # Required field: no default is supplied.
    url: str
class QuiverTextToSVGRequest(BaseModel):
    """Request body for text-to-SVG generation."""

    model: str = "arrow-preview"
    prompt: str  # required
    instructions: str | None = None
    # At most 4 reference images.
    references: list[QuiverImageObject] | None = Field(None, max_length=4)
    temperature: float | None = Field(None, ge=0, le=2)
    top_p: float | None = Field(None, ge=0, le=1)
    presence_penalty: float | None = Field(None, ge=-2, le=2)
class QuiverImageToSVGRequest(BaseModel):
    """Request body for image-to-SVG vectorization."""

    model: str = "arrow-preview"
    image: QuiverImageObject  # required
    auto_crop: bool | None = None
    target_size: int | None = Field(None, ge=128, le=4096)
    temperature: float | None = Field(None, ge=0, le=2)
    top_p: float | None = Field(None, ge=0, le=1)
    presence_penalty: float | None = Field(None, ge=-2, le=2)
class QuiverSVGResponseItem(BaseModel):
    """One SVG document returned in a response."""

    svg: str  # required
    mime_type: str | None = "image/svg+xml"
class QuiverSVGUsage(BaseModel):
    """Token usage counters; every counter is optional."""

    total_tokens: int | None = None
    input_tokens: int | None = None
    output_tokens: int | None = None
class QuiverSVGResponse(BaseModel):
    """Response envelope holding the generated SVG items."""

    id: str | None = None
    created: int | None = None
    data: list[QuiverSVGResponseItem]  # required
    usage: QuiverSVGUsage | None = None

View File

@ -0,0 +1,68 @@
from pydantic import BaseModel, Field
class RevePostprocessingOperation(BaseModel):
    """A single postprocessing step applied after generation."""

    process: str = Field(description="The postprocessing operation: upscale or remove_background.")
    upscale_factor: int | None = Field(
        default=None,
        ge=2,
        le=4,
        description="Upscale factor (2, 3, or 4). Only used when process is upscale.",
    )
class ReveImageCreateRequest(BaseModel):
    """Request body for creating an image from a prompt."""

    prompt: str  # required
    # Required, but an explicit null is accepted.
    aspect_ratio: str | None = Field(...)
    version: str  # required
    test_time_scaling: int = Field(
        ge=1,
        le=15,
        description="If included, the model will spend more effort making better images. Values between 1 and 15.",
    )
    postprocessing: list[RevePostprocessingOperation] | None = Field(
        default=None,
        description="Optional postprocessing operations to apply after generation.",
    )
class ReveImageEditRequest(BaseModel):
    """Request body for editing a reference image according to an instruction."""

    edit_instruction: str  # required
    reference_image: str = Field(description="A base64 encoded image to use as reference for the edit.")
    # Required, but an explicit null is accepted.
    aspect_ratio: str | None = Field(...)
    version: str  # required
    # Required, but an explicit null is accepted.
    test_time_scaling: int | None = Field(
        ...,
        ge=1,
        le=15,
        description="If included, the model will spend more effort making better images. Values between 1 and 15.",
    )
    postprocessing: list[RevePostprocessingOperation] | None = Field(
        default=None,
        description="Optional postprocessing operations to apply after generation.",
    )
class ReveImageRemixRequest(BaseModel):
    """Request body for remixing reference images with a prompt."""

    prompt: str  # required
    reference_images: list[str] = Field(description="A list of 1-6 base64 encoded reference images.")
    # Required, but an explicit null is accepted.
    aspect_ratio: str | None = Field(...)
    version: str  # required
    # Required, but an explicit null is accepted.
    test_time_scaling: int | None = Field(
        ...,
        ge=1,
        le=15,
        description="If included, the model will spend more effort making better images. Values between 1 and 15.",
    )
    postprocessing: list[RevePostprocessingOperation] | None = Field(
        default=None,
        description="Optional postprocessing operations to apply after generation.",
    )
class ReveImageResponse(BaseModel):
    """Response fields for an image request; every field is optional."""

    image: str | None = Field(default=None, description="The base64 encoded image data.")
    request_id: str | None = Field(default=None, description="A unique id for the request.")
    credits_used: float | None = Field(default=None, description="The number of credits used for this request.")
    version: str | None = Field(default=None, description="The specific model version used.")
    content_violation: bool | None = Field(
        default=None,
        description="Indicates whether the generated image violates the content policy.",
    )

View File

@ -47,6 +47,10 @@ SEEDREAM_MODELS = {
BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" BYTEPLUS_TASK_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks"
BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id} BYTEPLUS_TASK_STATUS_ENDPOINT = "/proxy/byteplus/api/v3/contents/generations/tasks" # + /{task_id}
DEPRECATED_MODELS = {"seedance-1-0-lite-t2v-250428", "seedance-1-0-lite-i2v-250428"}
logger = logging.getLogger(__name__)
def get_image_url_from_response(response: ImageTaskCreationResponse) -> str: def get_image_url_from_response(response: ImageTaskCreationResponse) -> str:
if response.error: if response.error:
@ -135,6 +139,7 @@ class ByteDanceImageNode(IO.ComfyNode):
price_badge=IO.PriceBadge( price_badge=IO.PriceBadge(
expr="""{"type":"usd","usd":0.03}""", expr="""{"type":"usd","usd":0.03}""",
), ),
is_deprecated=True,
) )
@classmethod @classmethod
@ -942,7 +947,7 @@ class ByteDanceImageReferenceNode(IO.ComfyNode):
] ]
return await process_video_task( return await process_video_task(
cls, cls,
payload=Image2VideoTaskCreationRequest(model=model, content=x), payload=Image2VideoTaskCreationRequest(model=model, content=x, generate_audio=None),
estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))), estimated_duration=max(1, math.ceil(VIDEO_TASKS_EXECUTION_TIME[model][resolution] * (duration / 10.0))),
) )
@ -952,6 +957,12 @@ async def process_video_task(
payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest, payload: Text2VideoTaskCreationRequest | Image2VideoTaskCreationRequest,
estimated_duration: int | None, estimated_duration: int | None,
) -> IO.NodeOutput: ) -> IO.NodeOutput:
if payload.model in DEPRECATED_MODELS:
logger.warning(
"Model '%s' is deprecated and will be deactivated on May 13, 2026. "
"Please switch to a newer model. Recommended: seedance-1-0-pro-fast-251015.",
payload.model,
)
initial_response = await sync_op( initial_response = await sync_op(
cls, cls,
ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"), ApiEndpoint(path=BYTEPLUS_TASK_ENDPOINT, method="POST"),

View File

@ -63,7 +63,7 @@ GEMINI_IMAGE_2_PRICE_BADGE = IO.PriceBadge(
$m := widgets.model; $m := widgets.model;
$r := widgets.resolution; $r := widgets.resolution;
$isFlash := $contains($m, "nano banana 2"); $isFlash := $contains($m, "nano banana 2");
$flashPrices := {"1k": 0.0696, "2k": 0.0696, "4k": 0.123}; $flashPrices := {"1k": 0.0696, "2k": 0.1014, "4k": 0.154};
$proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24}; $proPrices := {"1k": 0.134, "2k": 0.134, "4k": 0.24};
$prices := $isFlash ? $flashPrices : $proPrices; $prices := $isFlash ? $flashPrices : $proPrices;
{"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}} {"type":"usd","usd": $lookup($prices, $r), "format":{"suffix":"/Image","approximate":true}}
@ -188,10 +188,12 @@ def get_text_from_response(response: GeminiGenerateContentResponse) -> str:
return "\n".join([part.text for part in parts]) return "\n".join([part.text for part in parts])
async def get_image_from_response(response: GeminiGenerateContentResponse) -> Input.Image: async def get_image_from_response(response: GeminiGenerateContentResponse, thought: bool = False) -> Input.Image:
image_tensors: list[Input.Image] = [] image_tensors: list[Input.Image] = []
parts = get_parts_by_type(response, "image/*") parts = get_parts_by_type(response, "image/*")
for part in parts: for part in parts:
if (part.thought is True) != thought:
continue
if part.inlineData: if part.inlineData:
image_data = base64.b64decode(part.inlineData.data) image_data = base64.b64decode(part.inlineData.data)
returned_image = bytesio_to_image_tensor(BytesIO(image_data)) returned_image = bytesio_to_image_tensor(BytesIO(image_data))
@ -931,6 +933,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
outputs=[ outputs=[
IO.Image.Output(), IO.Image.Output(),
IO.String.Output(), IO.String.Output(),
IO.Image.Output(
display_name="thought_image",
tooltip="First image from the model's thinking process. "
"Only available with thinking_level HIGH and IMAGE+TEXT modality.",
),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -992,7 +999,11 @@ class GeminiNanoBanana2(IO.ComfyNode):
response_model=GeminiGenerateContentResponse, response_model=GeminiGenerateContentResponse,
price_extractor=calculate_tokens_price, price_extractor=calculate_tokens_price,
) )
return IO.NodeOutput(await get_image_from_response(response), get_text_from_response(response)) return IO.NodeOutput(
await get_image_from_response(response),
get_text_from_response(response),
await get_image_from_response(response, thought=True),
)
class GeminiExtension(ComfyExtension): class GeminiExtension(ComfyExtension):

View File

@ -1,3 +1,7 @@
import zipfile
from io import BytesIO
import torch
from typing_extensions import override from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input, Types from comfy_api.latest import IO, ComfyExtension, Input, Types
@ -17,7 +21,10 @@ from comfy_api_nodes.apis.hunyuan3d import (
) )
from comfy_api_nodes.util import ( from comfy_api_nodes.util import (
ApiEndpoint, ApiEndpoint,
bytesio_to_image_tensor,
download_url_to_bytesio,
download_url_to_file_3d, download_url_to_file_3d,
download_url_to_image_tensor,
downscale_image_tensor_by_max_side, downscale_image_tensor_by_max_side,
poll_op, poll_op,
sync_op, sync_op,
@ -36,6 +43,68 @@ def _is_tencent_rate_limited(status: int, body: object) -> bool:
) )
class ObjZipResult:
    """Holds an OBJ mesh together with its optional texture and PBR map images."""

    __slots__ = ("obj", "texture", "metallic", "normal", "roughness")

    def __init__(
        self,
        obj: Types.File3D,
        texture: Input.Image | None = None,
        metallic: Input.Image | None = None,
        normal: Input.Image | None = None,
        roughness: Input.Image | None = None,
    ):
        """Store the mesh and whichever images were supplied (missing ones stay ``None``)."""
        self.obj = obj
        self.texture, self.metallic = texture, metallic
        self.normal, self.roughness = normal, roughness
async def download_and_extract_obj_zip(url: str) -> ObjZipResult:
    """Download an OBJ result and unpack it when it is delivered as a ZIP archive.

    The Tencent API returns OBJ results as ZIP archives containing the .obj mesh
    and texture images. When PBR is enabled, the ZIP may contain additional
    metallic, normal, and roughness maps identified by their filename suffixes.

    If the downloaded payload is not a ZIP archive it is wrapped as-is as an OBJ
    file. When several archive members map to the same slot, the last one wins.

    Raises:
        ValueError: the archive contains no ``.obj`` member.
    """
    image_extensions = (".png", ".jpg", ".jpeg", ".bmp", ".tiff", ".webp")
    # Filename-stem suffix -> result slot; the first matching suffix is used.
    pbr_suffixes = (
        ("_metallic", "metallic"),
        ("_normal", "normal"),
        ("_roughness", "roughness"),
    )

    payload = BytesIO()
    await download_url_to_bytesio(url, payload)
    payload.seek(0)
    is_archive = zipfile.is_zipfile(payload)
    # Rewind again: the ZIP probe moves the stream position.
    payload.seek(0)
    if not is_archive:
        return ObjZipResult(obj=Types.File3D(source=payload, file_format="obj"))

    mesh_bytes = None
    images: dict[str, Input.Image] = {}
    with zipfile.ZipFile(payload) as archive:
        for member in archive.namelist():
            lowered = member.lower()
            if lowered.endswith(".obj"):
                mesh_bytes = archive.read(member)
                continue
            if not lowered.endswith(image_extensions):
                continue
            stem = lowered.rsplit(".", 1)[0]
            decoded = bytesio_to_image_tensor(BytesIO(archive.read(member)), mode="RGB")
            # Anything without a recognised PBR suffix is the base texture.
            slot = next((key for suffix, key in pbr_suffixes if stem.endswith(suffix)), "texture")
            images[slot] = decoded

    if mesh_bytes is None:
        raise ValueError("ZIP archive does not contain an OBJ file.")
    return ObjZipResult(
        obj=Types.File3D(source=BytesIO(mesh_bytes), file_format="obj"),
        texture=images.get("texture"),
        metallic=images.get("metallic"),
        normal=images.get("normal"),
        roughness=images.get("roughness"),
    )
def get_file_from_response( def get_file_from_response(
response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True response_objs: list[ResultFile3D], file_type: str, raise_if_not_found: bool = True
) -> ResultFile3D | None: ) -> ResultFile3D | None:
@ -93,6 +162,7 @@ class TencentTextToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"), IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -151,14 +221,14 @@ class TencentTextToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
await download_url_to_file_3d( obj_result.obj,
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id obj_result.texture,
),
) )
@ -211,6 +281,10 @@ class TencentImageToModelNode(IO.ComfyNode):
IO.String.Output(display_name="model_file"), # for backward compatibility only IO.String.Output(display_name="model_file"), # for backward compatibility only
IO.File3DGLB.Output(display_name="GLB"), IO.File3DGLB.Output(display_name="GLB"),
IO.File3DOBJ.Output(display_name="OBJ"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
IO.Image.Output(display_name="optional_metallic"),
IO.Image.Output(display_name="optional_normal"),
IO.Image.Output(display_name="optional_roughness"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -304,14 +378,17 @@ class TencentImageToModelNode(IO.ComfyNode):
response_model=To3DProTaskResultResponse, response_model=To3DProTaskResultResponse,
status_extractor=lambda r: r.Status, status_extractor=lambda r: r.Status,
) )
obj_result = await download_and_extract_obj_zip(get_file_from_response(result.ResultFile3Ds, "obj").Url)
return IO.NodeOutput( return IO.NodeOutput(
f"{task_id}.glb", f"{task_id}.glb",
await download_url_to_file_3d( await download_url_to_file_3d(
get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb", task_id=task_id
), ),
await download_url_to_file_3d( obj_result.obj,
get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj", task_id=task_id obj_result.texture,
), obj_result.metallic if obj_result.metallic is not None else torch.zeros(1, 1, 1, 3),
obj_result.normal if obj_result.normal is not None else torch.zeros(1, 1, 1, 3),
obj_result.roughness if obj_result.roughness is not None else torch.zeros(1, 1, 1, 3),
) )
@ -431,7 +508,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
], ],
outputs=[ outputs=[
IO.File3DGLB.Output(display_name="GLB"), IO.File3DGLB.Output(display_name="GLB"),
IO.File3DFBX.Output(display_name="FBX"), IO.File3DOBJ.Output(display_name="OBJ"),
IO.Image.Output(display_name="texture_image"),
], ],
hidden=[ hidden=[
IO.Hidden.auth_token_comfy_org, IO.Hidden.auth_token_comfy_org,
@ -480,7 +558,8 @@ class Tencent3DTextureEditNode(IO.ComfyNode):
) )
return IO.NodeOutput( return IO.NodeOutput(
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "glb").Url, "glb"),
await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "fbx").Url, "fbx"), await download_url_to_file_3d(get_file_from_response(result.ResultFile3Ds, "obj").Url, "obj"),
await download_url_to_image_tensor(get_file_from_response(result.ResultFile3Ds, "texture_image").Url),
) )
@ -654,7 +733,7 @@ class TencentHunyuan3DExtension(ComfyExtension):
TencentTextToModelNode, TencentTextToModelNode,
TencentImageToModelNode, TencentImageToModelNode,
TencentModelTo3DUVNode, TencentModelTo3DUVNode,
# Tencent3DTextureEditNode, Tencent3DTextureEditNode,
Tencent3DPartNode, Tencent3DPartNode,
TencentSmartTopologyNode, TencentSmartTopologyNode,
] ]

View File

@ -1459,6 +1459,7 @@ class OmniProEditVideoNode(IO.ComfyNode):
node_id="KlingOmniProEditVideoNode", node_id="KlingOmniProEditVideoNode",
display_name="Kling 3.0 Omni Edit Video", display_name="Kling 3.0 Omni Edit Video",
category="api node/video/Kling", category="api node/video/Kling",
essentials_category="Video Generation",
description="Edit an existing video with the latest model from Kling.", description="Edit an existing video with the latest model from Kling.",
inputs=[ inputs=[
IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]), IO.Combo.Input("model_name", options=["kling-v3-omni", "kling-video-o1"]),

View File

@ -0,0 +1,291 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension
from comfy_api_nodes.apis.quiver import (
QuiverImageObject,
QuiverImageToSVGRequest,
QuiverSVGResponse,
QuiverTextToSVGRequest,
)
from comfy_api_nodes.util import (
ApiEndpoint,
sync_op,
upload_image_to_comfyapi,
validate_string,
)
from comfy_extras.nodes_images import SVG
class QuiverTextToSVGNode(IO.ComfyNode):
    """API node that generates an SVG from a text prompt via the Quiver proxy endpoint."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs, SVG output, hidden auth fields and price badge."""
        return IO.Schema(
            node_id="QuiverTextToSVGNode",
            display_name="Quiver Text to SVG",
            category="api node/image/Quiver",
            description="Generate an SVG from a text prompt using Quiver AI.",
            inputs=[
                IO.String.Input(
                    "prompt",
                    multiline=True,
                    default="",
                    tooltip="Text description of the desired SVG output.",
                ),
                IO.String.Input(
                    "instructions",
                    multiline=True,
                    default="",
                    tooltip="Additional style or formatting guidance.",
                    optional=True,
                ),
                IO.Autogrow.Input(
                    "reference_images",
                    template=IO.Autogrow.TemplatePrefix(
                        IO.Image.Input("image"),
                        prefix="ref_",
                        min=0,
                        max=4,
                    ),
                    tooltip="Up to 4 reference images to guide the generation.",
                    optional=True,
                ),
                IO.DynamicCombo.Input(
                    "model",
                    options=[
                        IO.DynamicCombo.Option(
                            "arrow-preview",
                            [
                                IO.Float.Input(
                                    "temperature",
                                    default=1.0,
                                    min=0.0,
                                    max=2.0,
                                    step=0.1,
                                    display_mode=IO.NumberDisplay.slider,
                                    tooltip="Randomness control. Higher values increase randomness.",
                                    advanced=True,
                                ),
                                IO.Float.Input(
                                    "top_p",
                                    default=1.0,
                                    min=0.05,
                                    max=1.0,
                                    step=0.05,
                                    display_mode=IO.NumberDisplay.slider,
                                    tooltip="Nucleus sampling parameter.",
                                    advanced=True,
                                ),
                                IO.Float.Input(
                                    "presence_penalty",
                                    default=0.0,
                                    min=-2.0,
                                    max=2.0,
                                    step=0.1,
                                    display_mode=IO.NumberDisplay.slider,
                                    tooltip="Token presence penalty.",
                                    advanced=True,
                                ),
                            ],
                        ),
                    ],
                    tooltip="Model to use for SVG generation.",
                ),
                IO.Int.Input(
                    "seed",
                    default=0,
                    min=0,
                    max=2147483647,
                    control_after_generate=True,
                    tooltip="Seed to determine if node should re-run; "
                    "actual results are nondeterministic regardless of seed.",
                ),
            ],
            outputs=[
                IO.SVG.Output(),
            ],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.429}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        seed: int,
        instructions: str = None,
        reference_images: IO.Autogrow.Type = None,
    ) -> IO.NodeOutput:
        """Upload any reference images, call the generation endpoint and return the SVG(s).

        Args:
            prompt: Text description; must be at least one character long.
            model: Selected model plus its per-model options
                (``model``, ``temperature``, ``top_p``, ``presence_penalty``).
            seed: Not sent to the API; per the input tooltip it only controls re-runs.
            instructions: Optional guidance; blank or whitespace-only is sent as ``None``.
            reference_images: Optional keyed collection of images
                (assumed to be dict-like and to support ``len()`` - TODO confirm).

        Raises:
            ValueError: more than 4 reference images were supplied.
        """
        validate_string(prompt, strip_whitespace=False, min_length=1)

        references = None
        if reference_images:
            # Reject oversized input before uploading anything, so an invalid
            # request does not pay for uploads that will be thrown away.
            if len(reference_images) > 4:
                raise ValueError("Maximum 4 reference images are allowed.")
            references = []
            for key in reference_images:
                url = await upload_image_to_comfyapi(cls, reference_images[key])
                references.append(QuiverImageObject(url=url))

        # Blank or whitespace-only instructions are omitted from the request.
        instructions_val = (instructions or "").strip() or None

        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/quiver/v1/svgs/generations", method="POST"),
            response_model=QuiverSVGResponse,
            data=QuiverTextToSVGRequest(
                model=model["model"],
                prompt=prompt,
                instructions=instructions_val,
                references=references,
                temperature=model.get("temperature"),
                top_p=model.get("top_p"),
                presence_penalty=model.get("presence_penalty"),
            ),
        )

        svg_data = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
        return IO.NodeOutput(SVG(svg_data))
class QuiverImageToSVGNode(IO.ComfyNode):
    """API node that vectorizes a raster image into SVG via the Quiver proxy endpoint."""

    @classmethod
    def define_schema(cls):
        """Describe the node's inputs, SVG output, hidden auth fields and price badge."""
        # Options shown when the "arrow-preview" model is selected.
        arrow_preview_inputs = [
            IO.Int.Input(
                "target_size",
                default=1024,
                min=128,
                max=4096,
                tooltip="Square resize target in pixels.",
            ),
            IO.Float.Input(
                "temperature",
                default=1.0,
                min=0.0,
                max=2.0,
                step=0.1,
                display_mode=IO.NumberDisplay.slider,
                tooltip="Randomness control. Higher values increase randomness.",
                advanced=True,
            ),
            IO.Float.Input(
                "top_p",
                default=1.0,
                min=0.05,
                max=1.0,
                step=0.05,
                display_mode=IO.NumberDisplay.slider,
                tooltip="Nucleus sampling parameter.",
                advanced=True,
            ),
            IO.Float.Input(
                "presence_penalty",
                default=0.0,
                min=-2.0,
                max=2.0,
                step=0.1,
                display_mode=IO.NumberDisplay.slider,
                tooltip="Token presence penalty.",
                advanced=True,
            ),
        ]
        node_inputs = [
            IO.Image.Input(
                "image",
                tooltip="Input image to vectorize.",
            ),
            IO.Boolean.Input(
                "auto_crop",
                default=False,
                tooltip="Automatically crop to the dominant subject.",
            ),
            IO.DynamicCombo.Input(
                "model",
                options=[IO.DynamicCombo.Option("arrow-preview", arrow_preview_inputs)],
                tooltip="Model to use for SVG vectorization.",
            ),
            IO.Int.Input(
                "seed",
                default=0,
                min=0,
                max=2147483647,
                control_after_generate=True,
                tooltip="Seed to determine if node should re-run; "
                "actual results are nondeterministic regardless of seed.",
            ),
        ]
        return IO.Schema(
            node_id="QuiverImageToSVGNode",
            display_name="Quiver Image to SVG",
            category="api node/image/Quiver",
            description="Vectorize a raster image into SVG using Quiver AI.",
            inputs=node_inputs,
            outputs=[IO.SVG.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.429}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        image,
        auto_crop: bool,
        model: dict,
        seed: int,
    ) -> IO.NodeOutput:
        """Upload the image, call the vectorization endpoint and return the SVG(s).

        ``seed`` is not sent to the API. ``auto_crop`` is sent only when true;
        a false value is transmitted as ``None``.
        """
        uploaded_url = await upload_image_to_comfyapi(cls, image)
        request = QuiverImageToSVGRequest(
            model=model["model"],
            image=QuiverImageObject(url=uploaded_url),
            auto_crop=auto_crop or None,
            target_size=model.get("target_size"),
            temperature=model.get("temperature"),
            top_p=model.get("top_p"),
            presence_penalty=model.get("presence_penalty"),
        )
        response = await sync_op(
            cls,
            ApiEndpoint(path="/proxy/quiver/v1/svgs/vectorizations", method="POST"),
            response_model=QuiverSVGResponse,
            data=request,
        )
        documents = [BytesIO(item.svg.encode("utf-8")) for item in response.data]
        return IO.NodeOutput(SVG(documents))
class QuiverExtension(ComfyExtension):
    """Extension that exposes the Quiver SVG nodes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the node classes provided by this extension."""
        nodes: list[type[IO.ComfyNode]] = [QuiverTextToSVGNode, QuiverImageToSVGNode]
        return nodes
async def comfy_entrypoint() -> QuiverExtension:
    """Create and return a new ``QuiverExtension`` instance."""
    extension = QuiverExtension()
    return extension

View File

@ -833,6 +833,7 @@ class RecraftVectorizeImageNode(IO.ComfyNode):
node_id="RecraftVectorizeImageNode", node_id="RecraftVectorizeImageNode",
display_name="Recraft Vectorize Image", display_name="Recraft Vectorize Image",
category="api node/image/Recraft", category="api node/image/Recraft",
essentials_category="Image Tools",
description="Generates SVG synchronously from an input image.", description="Generates SVG synchronously from an input image.",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),

View File

@ -0,0 +1,395 @@
from io import BytesIO
from typing_extensions import override
from comfy_api.latest import IO, ComfyExtension, Input
from comfy_api_nodes.apis.reve import (
ReveImageCreateRequest,
ReveImageEditRequest,
ReveImageRemixRequest,
RevePostprocessingOperation,
)
from comfy_api_nodes.util import (
ApiEndpoint,
bytesio_to_image_tensor,
sync_op_raw,
tensor_to_base64_string,
validate_string,
)
def _build_postprocessing(upscale: dict, remove_background: bool) -> list[RevePostprocessingOperation] | None:
    """Turn the postprocessing widget values into request operations.

    Args:
        upscale: Widget dict; ``upscale["upscale"] == "enabled"`` turns the
            upscale step on and ``upscale["upscale_factor"]`` gives its factor.
        remove_background: Whether to append a background-removal step.

    Returns:
        The operations in order (upscale first, then background removal), or
        ``None`` when no operation was requested.
    """
    operations = []
    upscale_enabled = upscale["upscale"] == "enabled"
    if upscale_enabled:
        upscale_op = RevePostprocessingOperation(
            process="upscale",
            upscale_factor=upscale["upscale_factor"],
        )
        operations.append(upscale_op)
    if remove_background:
        operations.append(RevePostprocessingOperation(process="remove_background"))
    if not operations:
        return None
    return operations
def _postprocessing_inputs():
    """Return the ``upscale`` and ``remove_background`` inputs shared by the Reve nodes.

    A new list with new input objects is built on every call.
    """
    upscale_factor_input = IO.Int.Input(
        "upscale_factor",
        default=2,
        min=2,
        max=4,
        step=1,
        tooltip="Upscale factor (2x, 3x, or 4x).",
    )
    # The factor widget only exists under the "enabled" option.
    upscale_options = [
        IO.DynamicCombo.Option("disabled", []),
        IO.DynamicCombo.Option("enabled", [upscale_factor_input]),
    ]
    upscale_input = IO.DynamicCombo.Input(
        "upscale",
        options=upscale_options,
        tooltip="Upscale the generated image. May add additional cost.",
    )
    remove_background_input = IO.Boolean.Input(
        "remove_background",
        default=False,
        tooltip="Remove the background from the generated image. May add additional cost.",
    )
    return [upscale_input, remove_background_input]
def _reve_price_extractor(headers: dict) -> float | None:
    """Derive a price from the ``x-reve-credits-used`` response header.

    Returns the header value divided by 524.48, or ``None`` when the header is
    absent. NOTE(review): 524.48 is presumably the credits-per-USD rate; confirm.
    """
    raw_credits = headers.get("x-reve-credits-used")
    if raw_credits is None:
        return None
    return float(raw_credits) / 524.48
def _reve_response_header_validator(headers: dict) -> None:
    """Raise ``ValueError`` when the response headers report a Reve failure.

    A non-empty ``x-reve-error-code`` is checked first; after that a
    case-insensitive ``"true"`` in ``x-reve-content-violation`` is rejected.
    """
    if code := headers.get("x-reve-error-code"):
        raise ValueError(f"Reve API error: {code}")
    violation_flag = headers.get("x-reve-content-violation", "").lower()
    if violation_flag == "true":
        raise ValueError("The generated image was flagged for content policy violation.")
def _model_inputs(versions: list[str], aspect_ratios: list[str]):
    """Build one ``DynamicCombo`` option per model version.

    Every option carries its own ``aspect_ratio`` combo (offering
    ``aspect_ratios``) and an advanced ``test_time_scaling`` integer input.
    """
    options = []
    for version in versions:
        # Fresh input objects are created for each version.
        aspect_ratio_input = IO.Combo.Input(
            "aspect_ratio",
            options=aspect_ratios,
            tooltip="Aspect ratio of the output image.",
        )
        scaling_input = IO.Int.Input(
            "test_time_scaling",
            default=1,
            min=1,
            max=5,
            step=1,
            tooltip="Higher values produce better images but cost more credits.",
            advanced=True,
        )
        options.append(IO.DynamicCombo.Option(version, [aspect_ratio_input, scaling_input]))
    return options
class ReveImageCreateNode(IO.ComfyNode):
    """API node that generates an image from a text prompt via the Reve proxy endpoint."""

    @classmethod
    def define_schema(cls):
        """Declare the node's inputs, image output, hidden fields and price badge."""
        prompt_input = IO.String.Input(
            "prompt",
            multiline=True,
            default="",
            tooltip="Text description of the desired image. Maximum 2560 characters.",
        )
        model_input = IO.DynamicCombo.Input(
            "model",
            options=_model_inputs(
                ["reve-create@20250915"],
                aspect_ratios=["3:2", "16:9", "9:16", "2:3", "4:3", "3:4", "1:1"],
            ),
            tooltip="Model version to use for generation.",
        )
        postprocessing_inputs = _postprocessing_inputs()
        seed_input = IO.Int.Input(
            "seed",
            default=0,
            min=0,
            max=2147483647,
            control_after_generate=True,
            tooltip="Seed controls whether the node should re-run; "
            "results are non-deterministic regardless of seed.",
        )
        return IO.Schema(
            node_id="ReveImageCreateNode",
            display_name="Reve Image Create",
            category="api node/image/Reve",
            description="Generate images from text descriptions using Reve.",
            inputs=[prompt_input, model_input, *postprocessing_inputs, seed_input],
            outputs=[IO.Image.Output()],
            hidden=[
                IO.Hidden.auth_token_comfy_org,
                IO.Hidden.api_key_comfy_org,
                IO.Hidden.unique_id,
            ],
            is_api_node=True,
            price_badge=IO.PriceBadge(
                expr="""{"type":"usd","usd":0.03432,"format":{"approximate":true,"note":"(base)"}}""",
            ),
        )

    @classmethod
    async def execute(
        cls,
        prompt: str,
        model: dict,
        upscale: dict,
        remove_background: bool,
        seed: int,
    ) -> IO.NodeOutput:
        """Validate the prompt, call the create endpoint and return the decoded image.

        ``seed`` is accepted but is not part of the request built here.
        """
        validate_string(prompt, min_length=1, max_length=2560)
        endpoint = ApiEndpoint(
            path="/proxy/reve/v1/image/create",
            method="POST",
            headers={"Accept": "image/webp"},
        )
        payload = ReveImageCreateRequest(
            prompt=prompt,
            aspect_ratio=model["aspect_ratio"],
            version=model["model"],
            test_time_scaling=model["test_time_scaling"],
            postprocessing=_build_postprocessing(upscale, remove_background),
        )
        # as_binary=True: the response body is handed back as raw bytes.
        image_bytes = await sync_op_raw(
            cls,
            endpoint,
            as_binary=True,
            price_extractor=_reve_price_extractor,
            response_header_validator=_reve_response_header_validator,
            data=payload,
        )
        return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(image_bytes)))
# API node that edits a single input image according to a text instruction
# by posting to the Reve edit proxy endpoint.
class ReveImageEditNode(IO.ComfyNode):
@classmethod
# Declares the node's inputs, image output, hidden fields and price badge.
def define_schema(cls):
return IO.Schema(
node_id="ReveImageEditNode",
display_name="Reve Image Edit",
category="api node/image/Reve",
description="Edit images using natural language instructions with Reve.",
inputs=[
IO.Image.Input("image", tooltip="The image to edit."),
IO.String.Input(
"edit_instruction",
multiline=True,
default="",
tooltip="Text description of how to edit the image. Maximum 2560 characters.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-edit@20250915", "reve-edit-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for editing.",
),
# Shared upscale / remove_background widgets.
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
# The badge expression picks a base price depending on whether the selected
# model name contains "fast".
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
# Validates the instruction, sends the edit request and returns the decoded image.
# `seed` is accepted but is not part of the request built here.
async def execute(
cls,
image: Input.Image,
edit_instruction: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
# Instruction must be 1..2560 characters long.
validate_string(edit_instruction, min_length=1, max_length=2560)
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/edit",
method="POST",
headers={"Accept": "image/webp"},
),
# Ask for the raw response bytes rather than parsed JSON.
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageEditRequest(
edit_instruction=edit_instruction,
reference_image=tensor_to_base64_string(image),
# "auto" is passed as None; presumably the field is then left out of the
# payload (sync_op_raw dumps models with exclude_none) - TODO confirm.
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
# Only values above 1 are sent; 1 (the widget default) becomes None.
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
# API node that combines one to six reference images with a text prompt
# by posting to the Reve remix proxy endpoint.
class ReveImageRemixNode(IO.ComfyNode):
@classmethod
# Declares the node's inputs, image output, hidden fields and price badge.
def define_schema(cls):
return IO.Schema(
node_id="ReveImageRemixNode",
display_name="Reve Image Remix",
category="api node/image/Reve",
description="Combine reference images with text prompts to create new images using Reve.",
inputs=[
# Growing list of image inputs named with the "image_" prefix (1 to 6 entries).
IO.Autogrow.Input(
"reference_images",
template=IO.Autogrow.TemplatePrefix(
IO.Image.Input("image"),
prefix="image_",
min=1,
max=6,
),
),
IO.String.Input(
"prompt",
multiline=True,
default="",
tooltip="Text description of the desired image. "
"May include XML img tags to reference specific images by index, "
"e.g. <img>0</img>, <img>1</img>, etc.",
),
IO.DynamicCombo.Input(
"model",
options=_model_inputs(
["reve-remix@20250915", "reve-remix-fast@20251030"],
aspect_ratios=["auto", "16:9", "9:16", "3:2", "2:3", "4:3", "3:4", "1:1"],
),
tooltip="Model version to use for remixing.",
),
# Shared upscale / remove_background widgets.
*_postprocessing_inputs(),
IO.Int.Input(
"seed",
default=0,
min=0,
max=2147483647,
control_after_generate=True,
tooltip="Seed controls whether the node should re-run; "
"results are non-deterministic regardless of seed.",
),
],
outputs=[IO.Image.Output()],
hidden=[
IO.Hidden.auth_token_comfy_org,
IO.Hidden.api_key_comfy_org,
IO.Hidden.unique_id,
],
is_api_node=True,
# The badge expression picks a base price depending on whether the selected
# model name contains "fast".
price_badge=IO.PriceBadge(
depends_on=IO.PriceBadgeDepends(
widgets=["model"],
),
expr="""
(
$isFast := $contains(widgets.model, "fast");
$base := $isFast ? 0.01001 : 0.0572;
{"type": "usd", "usd": $base, "format": {"approximate": true, "note": "(base)"}}
)
""",
),
)
@classmethod
# Validates the inputs, sends the remix request and returns the decoded image.
# Raises ValueError when no reference image is given or more than six are given.
# `seed` is accepted but is not part of the request built here.
async def execute(
cls,
reference_images: IO.Autogrow.Type,
prompt: str,
model: dict,
upscale: dict,
remove_background: bool,
seed: int,
) -> IO.NodeOutput:
# Prompt must be 1..2560 characters long.
validate_string(prompt, min_length=1, max_length=2560)
if not reference_images:
raise ValueError("At least one reference image is required.")
# Encode every reference image, in the mapping's iteration order.
ref_base64_list = []
for key in reference_images:
ref_base64_list.append(tensor_to_base64_string(reference_images[key]))
# The count is checked after encoding, so an oversized set is still fully encoded first.
if len(ref_base64_list) > 6:
raise ValueError("Maximum 6 reference images are allowed.")
tts = model["test_time_scaling"]
ar = model["aspect_ratio"]
response = await sync_op_raw(
cls,
ApiEndpoint(
path="/proxy/reve/v1/image/remix",
method="POST",
headers={"Accept": "image/webp"},
),
# Ask for the raw response bytes rather than parsed JSON.
as_binary=True,
price_extractor=_reve_price_extractor,
response_header_validator=_reve_response_header_validator,
data=ReveImageRemixRequest(
prompt=prompt,
reference_images=ref_base64_list,
# "auto" is passed as None; presumably the field is then left out of the
# payload (sync_op_raw dumps models with exclude_none) - TODO confirm.
aspect_ratio=ar if ar != "auto" else None,
version=model["model"],
# Only values above 1 are sent; 1 (the widget default) becomes None.
test_time_scaling=tts if tts and tts > 1 else None,
postprocessing=_build_postprocessing(upscale, remove_background),
),
)
return IO.NodeOutput(bytesio_to_image_tensor(BytesIO(response)))
class ReveExtension(ComfyExtension):
    """Extension that exposes the Reve image node classes."""

    @override
    async def get_node_list(self) -> list[type[IO.ComfyNode]]:
        """Return the create, edit and remix node classes."""
        node_classes = [ReveImageCreateNode, ReveImageEditNode, ReveImageRemixNode]
        return node_classes
async def comfy_entrypoint() -> ReveExtension:
    """Build and return a new ``ReveExtension`` instance."""
    extension = ReveExtension()
    return extension

View File

@ -67,6 +67,7 @@ class _RequestConfig:
progress_origin_ts: float | None = None progress_origin_ts: float | None = None
price_extractor: Callable[[dict[str, Any]], float | None] | None = None price_extractor: Callable[[dict[str, Any]], float | None] | None = None
is_rate_limited: Callable[[int, Any], bool] | None = None is_rate_limited: Callable[[int, Any], bool] | None = None
response_header_validator: Callable[[dict[str, str]], None] | None = None
@dataclass @dataclass
@ -202,11 +203,13 @@ async def sync_op_raw(
monitor_progress: bool = True, monitor_progress: bool = True,
max_retries_on_rate_limit: int = 16, max_retries_on_rate_limit: int = 16,
is_rate_limited: Callable[[int, Any], bool] | None = None, is_rate_limited: Callable[[int, Any], bool] | None = None,
response_header_validator: Callable[[dict[str, str]], None] | None = None,
) -> dict[str, Any] | bytes: ) -> dict[str, Any] | bytes:
""" """
Make a single network request. Make a single network request.
- If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON). - If as_binary=False (default): returns JSON dict (or {'_raw': '<text>'} if non-JSON).
- If as_binary=True: returns bytes. - If as_binary=True: returns bytes.
- response_header_validator: optional callback receiving response headers dict
""" """
if isinstance(data, BaseModel): if isinstance(data, BaseModel):
data = data.model_dump(exclude_none=True) data = data.model_dump(exclude_none=True)
@ -232,6 +235,7 @@ async def sync_op_raw(
price_extractor=price_extractor, price_extractor=price_extractor,
max_retries_on_rate_limit=max_retries_on_rate_limit, max_retries_on_rate_limit=max_retries_on_rate_limit,
is_rate_limited=is_rate_limited, is_rate_limited=is_rate_limited,
response_header_validator=response_header_validator,
) )
return await _request_base(cfg, expect_binary=as_binary) return await _request_base(cfg, expect_binary=as_binary)
@ -769,6 +773,12 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total cfg.node_cls, cfg.wait_label, int(now - start_time), cfg.estimated_total
) )
bytes_payload = bytes(buff) bytes_payload = bytes(buff)
resp_headers = {k.lower(): v for k, v in resp.headers.items()}
if cfg.price_extractor:
with contextlib.suppress(Exception):
extracted_price = cfg.price_extractor(resp_headers)
if cfg.response_header_validator:
cfg.response_header_validator(resp_headers)
operation_succeeded = True operation_succeeded = True
final_elapsed_seconds = int(time.monotonic() - start_time) final_elapsed_seconds = int(time.monotonic() - start_time)
request_logger.log_request_response( request_logger.log_request_response(
@ -776,7 +786,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
request_method=method, request_method=method,
request_url=url, request_url=url,
response_status_code=resp.status, response_status_code=resp.status,
response_headers=dict(resp.headers), response_headers=resp_headers,
response_content=bytes_payload, response_content=bytes_payload,
) )
return bytes_payload return bytes_payload

View File

@ -0,0 +1,138 @@
from typing import Any, Optional, Tuple, List
import hashlib
import json
import logging
import threading
# Public types — source of truth is comfy_api.latest._caching
from comfy_api.latest._caching import CacheProvider, CacheContext, CacheValue # noqa: F401 (re-exported)
# Module logger; also imported by the cache implementation for provider warnings.
_logger = logging.getLogger(__name__)
# Registered providers in registration order; mutated only while holding _providers_lock.
_providers: List[CacheProvider] = []
# Guards every mutation of _providers and every rebuild of _providers_snapshot.
_providers_lock = threading.Lock()
# Immutable copy of _providers, rebuilt on each register/unregister/clear.
# Readers return this tuple directly without taking the lock.
_providers_snapshot: Tuple[CacheProvider, ...] = ()
def register_cache_provider(provider: CacheProvider) -> None:
    """Register an external cache provider. Providers are called in registration order.

    Registering a provider that is already present logs a warning and changes
    nothing. On success the read-only snapshot is rebuilt under the lock.
    """
    global _providers_snapshot
    with _providers_lock:
        already_registered = provider in _providers
        if not already_registered:
            _providers.append(provider)
            _providers_snapshot = tuple(_providers)
            _logger.debug(f"Registered cache provider: {provider.__class__.__name__}")
        else:
            _logger.warning(f"Provider {provider.__class__.__name__} already registered")
def unregister_cache_provider(provider: CacheProvider) -> None:
    """Remove a previously registered cache provider.

    Unregistering a provider that is not present logs a warning and changes
    nothing. On success the read-only snapshot is rebuilt under the lock.
    """
    global _providers_snapshot
    with _providers_lock:
        try:
            _providers.remove(provider)
        except ValueError:
            # list.remove raises ValueError when the provider is absent.
            _logger.warning(f"Provider {provider.__class__.__name__} was not registered")
        else:
            _providers_snapshot = tuple(_providers)
            _logger.debug(f"Unregistered cache provider: {provider.__class__.__name__}")
def _get_cache_providers() -> Tuple[CacheProvider, ...]:
    """Return the current immutable snapshot of registered providers (no locking)."""
    snapshot = _providers_snapshot
    return snapshot
def _has_cache_providers() -> bool:
    """Return True when at least one provider is in the current snapshot."""
    return len(_providers_snapshot) > 0
def _clear_cache_providers() -> None:
    """Drop every registered provider and reset the snapshot to an empty tuple."""
    global _providers_snapshot
    with _providers_lock:
        # Empty the list in place so the same list object stays in use.
        del _providers[:]
        _providers_snapshot = ()
def _canonicalize(obj: Any) -> Any:
    """Convert ``obj`` into a JSON-serializable form with deterministic ordering.

    Sets, frozensets and dicts are sorted by the JSON text of their
    canonicalized members, because their iteration order is not a stable
    basis for a key. Containers and scalars are tagged so that different
    types produce different canonical forms.

    Raises:
        ValueError: for any type that is not handled here, so that the caller
            can treat the key as non-cacheable.
    """
    def _in_json_order(entries):
        # Order entries by their JSON text.
        return sorted(entries, key=lambda entry: json.dumps(entry, sort_keys=True))

    if isinstance(obj, (frozenset, set)):
        tag = "__frozenset__" if isinstance(obj, frozenset) else "__set__"
        return (tag, _in_json_order([_canonicalize(member) for member in obj]))
    if isinstance(obj, tuple):
        return ("__tuple__", [_canonicalize(member) for member in obj])
    if isinstance(obj, list):
        return [_canonicalize(member) for member in obj]
    if isinstance(obj, dict):
        pairs = [[_canonicalize(key), _canonicalize(val)] for key, val in obj.items()]
        return {"__dict__": _in_json_order(pairs)}
    if isinstance(obj, (int, float, str, bool, type(None))):
        # The concrete type name keeps e.g. True and 1 apart.
        return (type(obj).__name__, obj)
    if isinstance(obj, bytes):
        return ("__bytes__", obj.hex())
    raise ValueError(f"Cannot canonicalize type: {type(obj).__name__}")
def _serialize_cache_key(cache_key: Any) -> Optional[str]:
    """Return a deterministic SHA256 hex digest of ``cache_key``, or None on failure.

    The key is canonicalized and dumped as compact JSON before hashing.
    JSON is used rather than pickle because pickle output is not
    deterministic across sessions. Any failure is logged as a warning.
    """
    try:
        payload = json.dumps(
            _canonicalize(cache_key),
            sort_keys=True,
            separators=(',', ':'),
        ).encode('utf-8')
        digest = hashlib.sha256(payload).hexdigest()
    except Exception as e:
        _logger.warning(f"Failed to serialize cache key: {e}")
        return None
    return digest
def _contains_self_unequal(obj: Any) -> bool:
    """Report whether ``obj`` is, or contains, a value that does not equal itself.

    The local cache matches keys with ``==``, so a value where
    ``not (x == x)`` (NaN, etc.) never hits locally, while its serialized form
    would match externally. Such keys are skipped. A comparison that raises
    is treated the same way.
    """
    try:
        equals_itself = bool(obj == obj)
    except Exception:
        return True
    if not equals_itself:
        return True

    if isinstance(obj, (frozenset, tuple, list, set)):
        for member in obj:
            if _contains_self_unequal(member):
                return True
        return False
    if isinstance(obj, dict):
        for key, val in obj.items():
            if _contains_self_unequal(key) or _contains_self_unequal(val):
                return True
        return False
    # Wrapper objects exposing a ``value`` attribute are checked through it.
    if hasattr(obj, 'value'):
        return _contains_self_unequal(obj.value)
    return False
def _estimate_value_size(value: CacheValue) -> int:
    """Estimate the number of bytes held by tensors inside ``value.outputs``.

    Dicts (values only), lists and tuples are searched to any depth. Every
    other object counts as zero. Returns 0 when torch cannot be imported.
    """
    try:
        import torch
    except ImportError:
        return 0

    size_in_bytes = 0
    # Explicit work list instead of recursion; visiting order does not affect the sum.
    pending = list(value.outputs)
    while pending:
        current = pending.pop()
        if isinstance(current, torch.Tensor):
            size_in_bytes += current.numel() * current.element_size()
        elif isinstance(current, dict):
            pending.extend(current.values())
        elif isinstance(current, (list, tuple)):
            pending.extend(current)
    return size_in_bytes

View File

@ -1,3 +1,4 @@
import asyncio
import bisect import bisect
import gc import gc
import itertools import itertools
@ -147,13 +148,15 @@ class CacheKeySetInputSignature(CacheKeySet):
self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping) self.get_ordered_ancestry_internal(dynprompt, ancestor_id, ancestors, order_mapping)
class BasicCache: class BasicCache:
def __init__(self, key_class): def __init__(self, key_class, enable_providers=False):
self.key_class = key_class self.key_class = key_class
self.initialized = False self.initialized = False
self.enable_providers = enable_providers
self.dynprompt: DynamicPrompt self.dynprompt: DynamicPrompt
self.cache_key_set: CacheKeySet self.cache_key_set: CacheKeySet
self.cache = {} self.cache = {}
self.subcaches = {} self.subcaches = {}
self._pending_store_tasks: set = set()
async def set_prompt(self, dynprompt, node_ids, is_changed_cache): async def set_prompt(self, dynprompt, node_ids, is_changed_cache):
self.dynprompt = dynprompt self.dynprompt = dynprompt
@ -196,18 +199,138 @@ class BasicCache:
def poll(self, **kwargs): def poll(self, **kwargs):
pass pass
def _set_immediate(self, node_id, value): def get_local(self, node_id):
assert self.initialized
cache_key = self.cache_key_set.get_data_key(node_id)
self.cache[cache_key] = value
def _get_immediate(self, node_id):
if not self.initialized: if not self.initialized:
return None return None
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key in self.cache: if cache_key in self.cache:
return self.cache[cache_key] return self.cache[cache_key]
else: return None
def set_local(self, node_id, value):
    """Store ``value`` in the local cache only; providers are not notified."""
    assert self.initialized
    data_key = self.cache_key_set.get_data_key(node_id)
    self.cache.__setitem__(data_key, value)
async def _set_immediate(self, node_id, value):
    """Store ``value`` in the local cache, then hand it to the provider store hook."""
    assert self.initialized
    data_key = self.cache_key_set.get_data_key(node_id)
    self.cache[data_key] = value
    # The same key that was stored locally is passed on to the providers.
    await self._notify_providers_store(node_id, data_key, value)
async def _get_immediate(self, node_id):
    """Return the cached value for ``node_id``, asking providers on a local miss.

    Returns None when the cache is not initialized or nothing is found. A
    value returned by a provider is written into the local cache first.
    """
    if not self.initialized:
        return None
    data_key = self.cache_key_set.get_data_key(node_id)
    found_locally = data_key in self.cache
    if found_locally:
        return self.cache[data_key]
    fetched = await self._check_providers_lookup(node_id, data_key)
    if fetched is None:
        return None
    self.cache[data_key] = fetched
    return fetched
async def _notify_providers_store(self, node_id, cache_key, value):
    """Schedule background ``on_store`` calls for each provider that wants the value.

    Nothing happens when providers are disabled on this cache, when none are
    registered, when ``value`` lacks ``outputs``/``ui``, when the key holds a
    self-unequal value, or when no context can be built. Store tasks are kept
    in ``self._pending_store_tasks`` until they finish.
    """
    from comfy_execution.cache_provider import (
        _has_cache_providers, _get_cache_providers,
        CacheValue, _contains_self_unequal, _logger
    )

    providers_usable = self.enable_providers and _has_cache_providers()
    if not providers_usable or not self._is_external_cacheable_value(value):
        return
    if _contains_self_unequal(cache_key):
        return

    context = self._build_context(node_id, cache_key)
    if context is None:
        return

    cache_value = CacheValue(outputs=value.outputs, ui=value.ui)
    for provider in _get_cache_providers():
        try:
            if not provider.should_cache(context, cache_value):
                continue
            store_task = asyncio.create_task(self._safe_provider_store(provider, context, cache_value))
            # Hold a reference until completion, then drop it via the done callback.
            self._pending_store_tasks.add(store_task)
            store_task.add_done_callback(self._pending_store_tasks.discard)
        except Exception as e:
            _logger.warning(f"Cache provider {provider.__class__.__name__} error on store: {e}")
@staticmethod
async def _safe_provider_store(provider, context, cache_value):
    """Await ``provider.on_store`` and log, rather than propagate, any exception."""
    from comfy_execution.cache_provider import _logger

    try:
        await provider.on_store(context, cache_value)
    except Exception as store_error:
        _logger.warning(f"Cache provider {provider.__class__.__name__} async store error: {store_error}")
async def _check_providers_lookup(self, node_id, cache_key):
    """Ask the registered providers, in order, for an entry matching ``cache_key``.

    Returns a ``CacheEntry`` built from the first valid provider result, or
    None when providers are disabled or absent, the key holds a self-unequal
    value, no context can be built, or no provider returns a usable result.
    Provider errors are logged and the next provider is tried.
    """
    from comfy_execution.cache_provider import (
        _has_cache_providers, _get_cache_providers,
        CacheValue, _contains_self_unequal, _logger
    )

    providers_usable = self.enable_providers and _has_cache_providers()
    if not providers_usable or _contains_self_unequal(cache_key):
        return None

    context = self._build_context(node_id, cache_key)
    if context is None:
        return None

    for provider in _get_cache_providers():
        try:
            if not provider.should_cache(context):
                continue
            found = await provider.on_lookup(context)
            if found is None:
                continue
            if not isinstance(found, CacheValue):
                _logger.warning(f"Provider {provider.__class__.__name__} returned invalid type")
                continue
            if not isinstance(found.outputs, (list, tuple)):
                _logger.warning(f"Provider {provider.__class__.__name__} returned invalid outputs")
                continue
            # Imported only once a valid result exists, as in the original flow.
            from execution import CacheEntry
            return CacheEntry(ui=found.ui, outputs=list(found.outputs))
        except Exception as e:
            _logger.warning(f"Cache provider {provider.__class__.__name__} error on lookup: {e}")

    return None
def _is_external_cacheable_value(self, value):
    """Return True when ``value`` exposes both an ``outputs`` and a ``ui`` attribute."""
    if not hasattr(value, 'outputs'):
        return False
    return hasattr(value, 'ui')
def _get_class_type(self, node_id):
    """Return the node's ``class_type`` from the current prompt, or '' if unavailable.

    An empty string is returned when the cache is not initialized, no prompt
    is set, the node has no ``class_type``, or the lookup raises.
    """
    prompt_ready = self.initialized and self.dynprompt
    if not prompt_ready:
        return ''
    try:
        node = self.dynprompt.get_node(node_id)
        return node.get('class_type', '')
    except Exception:
        return ''
def _build_context(self, node_id, cache_key):
from comfy_execution.cache_provider import CacheContext, _serialize_cache_key, _logger
try:
cache_key_hash = _serialize_cache_key(cache_key)
if cache_key_hash is None:
return None
return CacheContext(
node_id=node_id,
class_type=self._get_class_type(node_id),
cache_key_hash=cache_key_hash,
)
except Exception as e:
_logger.warning(f"Failed to build cache context for node {node_id}: {e}")
return None return None
async def _ensure_subcache(self, node_id, children_ids): async def _ensure_subcache(self, node_id, children_ids):
@ -236,8 +359,8 @@ class BasicCache:
return result return result
class HierarchicalCache(BasicCache): class HierarchicalCache(BasicCache):
def __init__(self, key_class): def __init__(self, key_class, enable_providers=False):
super().__init__(key_class) super().__init__(key_class, enable_providers=enable_providers)
def _get_cache_for(self, node_id): def _get_cache_for(self, node_id):
assert self.dynprompt is not None assert self.dynprompt is not None
@ -257,16 +380,27 @@ class HierarchicalCache(BasicCache):
return None return None
return cache return cache
def get(self, node_id): async def get(self, node_id):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
if cache is None: if cache is None:
return None return None
return cache._get_immediate(node_id) return await cache._get_immediate(node_id)
def set(self, node_id, value): def get_local(self, node_id):
cache = self._get_cache_for(node_id)
if cache is None:
return None
return BasicCache.get_local(cache, node_id)
async def set(self, node_id, value):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
assert cache is not None assert cache is not None
cache._set_immediate(node_id, value) await cache._set_immediate(node_id, value)
def set_local(self, node_id, value):
cache = self._get_cache_for(node_id)
assert cache is not None
BasicCache.set_local(cache, node_id, value)
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
cache = self._get_cache_for(node_id) cache = self._get_cache_for(node_id)
@ -287,18 +421,24 @@ class NullCache:
def poll(self, **kwargs): def poll(self, **kwargs):
pass pass
def get(self, node_id): async def get(self, node_id):
return None return None
def set(self, node_id, value): def get_local(self, node_id):
return None
async def set(self, node_id, value):
pass
def set_local(self, node_id, value):
pass pass
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
return self return self
class LRUCache(BasicCache): class LRUCache(BasicCache):
def __init__(self, key_class, max_size=100): def __init__(self, key_class, max_size=100, enable_providers=False):
super().__init__(key_class) super().__init__(key_class, enable_providers=enable_providers)
self.max_size = max_size self.max_size = max_size
self.min_generation = 0 self.min_generation = 0
self.generation = 0 self.generation = 0
@ -322,18 +462,18 @@ class LRUCache(BasicCache):
del self.children[key] del self.children[key]
self._clean_subcaches() self._clean_subcaches()
def get(self, node_id): async def get(self, node_id):
self._mark_used(node_id) self._mark_used(node_id)
return self._get_immediate(node_id) return await self._get_immediate(node_id)
def _mark_used(self, node_id): def _mark_used(self, node_id):
cache_key = self.cache_key_set.get_data_key(node_id) cache_key = self.cache_key_set.get_data_key(node_id)
if cache_key is not None: if cache_key is not None:
self.used_generation[cache_key] = self.generation self.used_generation[cache_key] = self.generation
def set(self, node_id, value): async def set(self, node_id, value):
self._mark_used(node_id) self._mark_used(node_id)
return self._set_immediate(node_id, value) return await self._set_immediate(node_id, value)
async def ensure_subcache_for(self, node_id, children_ids): async def ensure_subcache_for(self, node_id, children_ids):
# Just uses subcaches for tracking 'live' nodes # Just uses subcaches for tracking 'live' nodes
@ -366,20 +506,20 @@ RAM_CACHE_OLD_WORKFLOW_OOM_MULTIPLIER = 1.3
class RAMPressureCache(LRUCache): class RAMPressureCache(LRUCache):
def __init__(self, key_class): def __init__(self, key_class, enable_providers=False):
super().__init__(key_class, 0) super().__init__(key_class, 0, enable_providers=enable_providers)
self.timestamps = {} self.timestamps = {}
def clean_unused(self): def clean_unused(self):
self._clean_subcaches() self._clean_subcaches()
def set(self, node_id, value): async def set(self, node_id, value):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
super().set(node_id, value) await super().set(node_id, value)
def get(self, node_id): async def get(self, node_id):
self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time() self.timestamps[self.cache_key_set.get_data_key(node_id)] = time.time()
return super().get(node_id) return await super().get(node_id)
def poll(self, ram_headroom): def poll(self, ram_headroom):
def _ram_gb(): def _ram_gb():

View File

@ -204,12 +204,12 @@ class ExecutionList(TopologicalSort):
self.execution_cache_listeners = {} self.execution_cache_listeners = {}
def is_cached(self, node_id): def is_cached(self, node_id):
return self.output_cache.get(node_id) is not None return self.output_cache.get_local(node_id) is not None
def cache_link(self, from_node_id, to_node_id): def cache_link(self, from_node_id, to_node_id):
if to_node_id not in self.execution_cache: if to_node_id not in self.execution_cache:
self.execution_cache[to_node_id] = {} self.execution_cache[to_node_id] = {}
self.execution_cache[to_node_id][from_node_id] = self.output_cache.get(from_node_id) self.execution_cache[to_node_id][from_node_id] = self.output_cache.get_local(from_node_id)
if from_node_id not in self.execution_cache_listeners: if from_node_id not in self.execution_cache_listeners:
self.execution_cache_listeners[from_node_id] = set() self.execution_cache_listeners[from_node_id] = set()
self.execution_cache_listeners[from_node_id].add(to_node_id) self.execution_cache_listeners[from_node_id].add(to_node_id)
@ -221,7 +221,7 @@ class ExecutionList(TopologicalSort):
if value is None: if value is None:
return None return None
#Write back to the main cache on touch. #Write back to the main cache on touch.
self.output_cache.set(from_node_id, value) self.output_cache.set_local(from_node_id, value)
return value return value
def cache_update(self, node_id, value): def cache_update(self, node_id, value):

View File

@ -19,6 +19,7 @@ class EmptyLatentAudio(IO.ComfyNode):
node_id="EmptyLatentAudio", node_id="EmptyLatentAudio",
display_name="Empty Latent Audio", display_name="Empty Latent Audio",
category="latent/audio", category="latent/audio",
essentials_category="Audio",
inputs=[ inputs=[
IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1), IO.Float.Input("seconds", default=47.6, min=1.0, max=1000.0, step=0.1),
IO.Int.Input( IO.Int.Input(
@ -185,6 +186,7 @@ class SaveAudioMP3(IO.ComfyNode):
search_aliases=["export mp3"], search_aliases=["export mp3"],
display_name="Save Audio (MP3)", display_name="Save Audio (MP3)",
category="audio", category="audio",
essentials_category="Audio",
inputs=[ inputs=[
IO.Audio.Input("audio"), IO.Audio.Input("audio"),
IO.String.Input("filename_prefix", default="audio/ComfyUI"), IO.String.Input("filename_prefix", default="audio/ComfyUI"),

View File

@ -3,6 +3,7 @@ from typing_extensions import override
import comfy.model_management import comfy.model_management
from comfy_api.latest import ComfyExtension, io from comfy_api.latest import ComfyExtension, io
import torch
class Canny(io.ComfyNode): class Canny(io.ComfyNode):
@ -29,8 +30,8 @@ class Canny(io.ComfyNode):
@classmethod @classmethod
def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput: def execute(cls, image, low_threshold, high_threshold) -> io.NodeOutput:
output = canny(image.to(comfy.model_management.get_torch_device()).movedim(-1, 1), low_threshold, high_threshold) output = canny(image.to(device=comfy.model_management.get_torch_device(), dtype=torch.float32).movedim(-1, 1), low_threshold, high_threshold)
img_out = output[1].to(comfy.model_management.intermediate_device()).repeat(1, 3, 1, 1).movedim(1, -1) img_out = output[1].to(device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype()).repeat(1, 3, 1, 1).movedim(1, -1)
return io.NodeOutput(img_out) return io.NodeOutput(img_out)

View File

@ -27,8 +27,8 @@ class ContextWindowsManualNode(io.ComfyNode):
io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."), io.Combo.Input("fuse_method", options=comfy.context_windows.ContextFuseMethods.LIST_STATIC, default=comfy.context_windows.ContextFuseMethods.PYRAMID, tooltip="The method to use to fuse the context windows."),
io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."), io.Int.Input("dim", min=0, max=5, default=0, tooltip="The dimension to apply the context windows to."),
io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."), io.Boolean.Input("freenoise", default=False, tooltip="Whether to apply FreeNoise noise shuffling, improves window blending."),
#io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."), io.String.Input("cond_retain_index_list", default="", tooltip="List of latent indices to retain in the conditioning tensors for each window, for example setting this to '0' will use the initial start image for each window."),
#io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."), io.Boolean.Input("split_conds_to_windows", default=False, tooltip="Whether to split multiple conditionings (created by ConditionCombine) to each window based on region index."),
], ],
outputs=[ outputs=[
io.Model.Output(tooltip="The model with context windows applied during sampling."), io.Model.Output(tooltip="The model with context windows applied during sampling."),

View File

@ -6,6 +6,7 @@ import comfy.model_management
import torch import torch
import math import math
import nodes import nodes
import comfy.ldm.flux.math
class CLIPTextEncodeFlux(io.ComfyNode): class CLIPTextEncodeFlux(io.ComfyNode):
@classmethod @classmethod
@ -231,6 +232,68 @@ class Flux2Scheduler(io.ComfyNode):
sigmas = get_schedule(steps, round(seq_len)) sigmas = get_schedule(steps, round(seq_len))
return io.NodeOutput(sigmas) return io.NodeOutput(sigmas)
class KV_Attn_Input:
    """Attention input patch that caches reference-image keys/values per block.

    On the first call for a given block, the trailing ``ref_toks`` entries of
    ``k`` and ``v`` along dim 2 are cloned into ``self.cache``. On later calls
    for the same block, the cached tensors are appended to the incoming ``k``
    and ``v`` along dim 2.

    NOTE(review): dim 2 is presumably the token/sequence axis of the attention
    inputs - confirm against the attention layout used by the caller.
    """

    def __init__(self):
        # Maps "<block_type>_<block_index>" -> (k_ref, v_ref) cloned tensors.
        self.cache = {}
        # True when the most recent call stored a new cache entry, False when
        # it reused one. Initialised here so it can be read before any call.
        self.set_cache = False

    def __call__(self, q, k, v, extra_options, **kwargs):
        """Cache or re-attach the reference-image part of ``k``/``v``.

        Args:
            q, k, v: attention input tensors; ``k`` and ``v`` are sliced and
                concatenated along dim 2.
            extra_options: mapping read for ``reference_image_num_tokens``
                (list of per-image token counts), ``block_type`` and
                ``block_index``.
            **kwargs: accepted and ignored.

        Returns:
            An empty dict when there are no reference tokens (presumably
            meaning "leave q/k/v unchanged" - confirm with the caller),
            otherwise a dict with keys ``"q"``, ``"k"`` and ``"v"``.
        """
        reference_image_num_tokens = extra_options.get("reference_image_num_tokens", [])
        ref_toks = sum(reference_image_num_tokens)
        # Guard on the total rather than the list length: with a total of 0 the
        # slice ``[-0:]`` would select the whole tensor and cache all of k/v.
        if ref_toks <= 0:
            return {}

        cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
        if cache_key in self.cache:
            kk, vv = self.cache[cache_key]
            self.set_cache = False
            return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}

        # First visit of this block: remember the reference part for later calls.
        self.cache[cache_key] = (k[:, :, -ref_toks:].clone(), v[:, :, -ref_toks:].clone())
        self.set_cache = True
        return {"q": q, "k": k, "v": v}

    def cleanup(self):
        """Drop every cached key/value pair."""
        self.cache = {}
class FluxKVCache(io.ComfyNode):
    """Experimental node that installs KV-cache patches on a cloned model."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare the node: one model input, one patched model output."""
        model_input = io.Model.Input("model", tooltip="The model to use KV Cache on.")
        model_output = io.Model.Output(tooltip="The patched model with KV Cache enabled.")
        return io.Schema(
            node_id="FluxKVCache",
            display_name="Flux KV Cache",
            description="Enables KV Cache optimization for reference images on Flux family models.",
            category="",
            is_experimental=True,
            inputs=[model_input],
            outputs=[model_output],
        )

    @classmethod
    def execute(cls, model: io.Model.Type) -> io.NodeOutput:
        """Clone ``model`` and attach the KV-cache attention and input patches.

        The attention patch (``KV_Attn_Input``) stores reference-image keys and
        values; once it holds any entry, the input patch trims the trailing
        reference-image tokens from ``inputs["img"]``.
        """
        patched = model.clone()
        kv_patch = KV_Attn_Input()

        def model_input_patch(inputs):
            # Nothing cached yet: pass the inputs through untouched.
            if not kv_patch.cache:
                return inputs
            token_counts = inputs["transformer_options"].get("reference_image_num_tokens", [])
            ref_image_tokens = sum(token_counts)
            if ref_image_tokens > 0:
                # Drop the trailing reference tokens along dim 1.
                inputs["img"] = inputs["img"][:, :-ref_image_tokens]
            return inputs

        patched.set_model_attn1_patch(kv_patch)
        patched.set_model_post_input_patch(model_input_patch)

        # The attribute lives under ``params`` when the diffusion model has one.
        if hasattr(model.model.diffusion_model, "params"):
            ref_method_key = "diffusion_model.params.default_ref_method"
        else:
            ref_method_key = "diffusion_model.default_ref_method"
        patched.add_object_patch(ref_method_key, "index_timestep_zero")

        return io.NodeOutput(patched)
class FluxExtension(ComfyExtension): class FluxExtension(ComfyExtension):
@override @override
@ -243,6 +306,7 @@ class FluxExtension(ComfyExtension):
FluxKontextMultiReferenceLatentMethod, FluxKontextMultiReferenceLatentMethod,
EmptyFlux2LatentImage, EmptyFlux2LatentImage,
Flux2Scheduler, Flux2Scheduler,
FluxKVCache,
] ]

View File

@ -14,6 +14,7 @@ class ImageCompare(IO.ComfyNode):
display_name="Image Compare", display_name="Image Compare",
description="Compares two images side by side with a slider.", description="Compares two images side by side with a slider.",
category="image", category="image",
essentials_category="Image Tools",
is_experimental=True, is_experimental=True,
is_output_node=True, is_output_node=True,
inputs=[ inputs=[

View File

@ -58,6 +58,7 @@ class ImageCropV2(IO.ComfyNode):
search_aliases=["trim"], search_aliases=["trim"],
display_name="Image Crop", display_name="Image Crop",
category="image/transform", category="image/transform",
essentials_category="Image Tools",
inputs=[ inputs=[
IO.Image.Input("image"), IO.Image.Input("image"),
IO.BoundingBox.Input("crop_region", component="ImageCrop"), IO.BoundingBox.Input("crop_region", component="ImageCrop"),

View File

@ -3,6 +3,7 @@ import node_helpers
import torch import torch
import comfy.model_management import comfy.model_management
import comfy.model_sampling import comfy.model_sampling
import comfy.samplers
import comfy.utils import comfy.utils
import math import math
import numpy as np import numpy as np
@ -682,6 +683,84 @@ class LTXVSeparateAVLatent(io.ComfyNode):
return io.NodeOutput(video_latent, audio_latent) return io.NodeOutput(video_latent, audio_latent)
class LTXVReferenceAudio(io.ComfyNode):
    """Node that attaches reference-audio tokens to conditioning and adds identity guidance."""

    @classmethod
    def define_schema(cls) -> io.Schema:
        """Declare inputs (model, conditionings, audio, VAE, guidance settings) and outputs."""
        node_inputs = [
            io.Model.Input("model"),
            io.Conditioning.Input("positive"),
            io.Conditioning.Input("negative"),
            io.Audio.Input("reference_audio", tooltip="Reference audio clip whose speaker identity to transfer. ~5 seconds recommended (training duration). Shorter or longer clips may degrade voice identity transfer."),
            io.Vae.Input(id="audio_vae", display_name="Audio VAE", tooltip="LTXV Audio VAE for encoding."),
            io.Float.Input("identity_guidance_scale", default=3.0, min=0.0, max=100.0, step=0.01, round=0.01, tooltip="Strength of identity guidance. Runs an extra forward pass without reference each step to amplify speaker identity. Set to 0 to disable (no extra pass)."),
            io.Float.Input("start_percent", default=0.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="Start of the sigma range where identity guidance is active."),
            io.Float.Input("end_percent", default=1.0, min=0.0, max=1.0, step=0.001, advanced=True, tooltip="End of the sigma range where identity guidance is active."),
        ]
        node_outputs = [
            io.Model.Output(),
            io.Conditioning.Output(display_name="positive"),
            io.Conditioning.Output(display_name="negative"),
        ]
        return io.Schema(
            node_id="LTXVReferenceAudio",
            display_name="LTXV Reference Audio (ID-LoRA)",
            category="conditioning/audio",
            description="Set reference audio for ID-LoRA speaker identity transfer. Encodes a reference audio clip into the conditioning and optionally patches the model with identity guidance (extra forward pass without reference, amplifying the speaker identity effect).",
            inputs=node_inputs,
            outputs=node_outputs,
        )

    @classmethod
    def execute(cls, model, positive, negative, reference_audio, audio_vae, identity_guidance_scale, start_percent, end_percent) -> io.NodeOutput:
        """Encode the reference audio, store it on both conditionings and patch the model.

        Returns the patched model clone plus the updated positive and negative
        conditionings.
        """
        # Encode to a 4-D latent and flatten dims 1 and 3 into one feature axis
        # per step of dim 2.
        latents = audio_vae.encode(reference_audio)
        batch, channels, steps, feats = latents.shape
        tokens = latents.permute(0, 2, 1, 3).reshape(batch, steps, channels * feats)
        ref_audio = {"tokens": tokens}

        positive = node_helpers.conditioning_set_values(positive, {"ref_audio": ref_audio})
        negative = node_helpers.conditioning_set_values(negative, {"ref_audio": ref_audio})

        patched = model.clone()
        scale = identity_guidance_scale
        sampling = patched.get_model_object("model_sampling")
        sigma_start = sampling.percent_to_sigma(start_percent)
        sigma_end = sampling.percent_to_sigma(end_percent)

        def _without_ref_audio(entry):
            # Shallow-copy the entry and its model_conds so the caller's
            # conditioning is not mutated when "ref_audio" is removed.
            stripped = entry.copy()
            model_conds = stripped.get("model_conds", {}).copy()
            model_conds.pop("ref_audio", None)
            stripped["model_conds"] = model_conds
            return stripped

        def post_cfg_function(args):
            # Guidance disabled: keep the CFG result as is.
            if scale == 0:
                return args["denoised"]

            sigma = args["sigma"]
            current_sigma = sigma[0].item()
            # Outside the configured sigma window: keep the CFG result as is.
            if current_sigma > sigma_start or current_sigma < sigma_end:
                return args["denoised"]

            cond_pred = args["cond_denoised"]
            cfg_result = args["denoised"]
            model_options = args["model_options"].copy()

            # Extra pass with the reference audio stripped from the conditioning.
            noref_cond = [_without_ref_audio(entry) for entry in args["cond"]]
            (pred_noref,) = comfy.samplers.calc_cond_batch(
                args["model"], [noref_cond], args["input"], sigma, model_options
            )

            # Push the result along the (with reference - without reference) direction.
            return cfg_result + (cond_pred - pred_noref) * scale

        patched.set_model_sampler_post_cfg_function(post_cfg_function)

        return io.NodeOutput(patched, positive, negative)
class LtxvExtension(ComfyExtension): class LtxvExtension(ComfyExtension):
@override @override
async def get_node_list(self) -> list[type[io.ComfyNode]]: async def get_node_list(self) -> list[type[io.ComfyNode]]:
@ -697,6 +776,7 @@ class LtxvExtension(ComfyExtension):
LTXVCropGuides, LTXVCropGuides,
LTXVConcatAVLatent, LTXVConcatAVLatent,
LTXVSeparateAVLatent, LTXVSeparateAVLatent,
LTXVReferenceAudio,
] ]

View File

@ -0,0 +1,127 @@
from __future__ import annotations
import hashlib
import os
import numpy as np
import torch
from PIL import Image
import folder_paths
import node_helpers
from comfy_api.latest import ComfyExtension, io, UI
from typing_extensions import override
def hex_to_rgb(hex_color: str) -> tuple[float, float, float]:
    """Convert a ``#rrggbb`` / ``rrggbb`` string to an ``(r, g, b)`` float tuple.

    Args:
        hex_color: colour string; leading ``#`` characters are ignored.

    Returns:
        Three floats in ``[0.0, 1.0]``. Any input that is not exactly six
        hexadecimal digits yields black, ``(0.0, 0.0, 0.0)``.
    """
    black = (0.0, 0.0, 0.0)
    hex_color = hex_color.lstrip("#")
    if len(hex_color) != 6:
        return black
    try:
        # bytes.fromhex rejects non-hex characters (including signs), and the
        # 3-way unpack rejects anything that does not decode to three bytes.
        r, g, b = bytes.fromhex(hex_color)
    except ValueError:
        # Treat malformed digits like a wrong length instead of crashing.
        return black
    return (r / 255.0, g / 255.0, b / 255.0)
class PainterNode(io.ComfyNode):
    """Node that composites a painted RGBA layer over a base image or solid colour."""

    @classmethod
    def define_schema(cls):
        """Declare the optional base image, painter widget, hidden size inputs and colour."""

        def hidden_size_input(name):
            # Width and height share the same hidden, socketless configuration.
            return io.Int.Input(
                name,
                default=512,
                min=64,
                max=4096,
                step=64,
                socketless=True,
                extra_dict={"hidden": True},
            )

        image_input = io.Image.Input(
            "image",
            optional=True,
            tooltip="Optional base image to paint over",
        )
        mask_input = io.String.Input(
            "mask",
            default="",
            socketless=True,
            extra_dict={"widgetType": "PAINTER", "image_upload": True},
        )
        return io.Schema(
            node_id="Painter",
            display_name="Painter",
            category="image",
            inputs=[
                image_input,
                mask_input,
                hidden_size_input("width"),
                hidden_size_input("height"),
                io.Color.Input("bg_color", default="#000000"),
            ],
            outputs=[
                io.Image.Output("IMAGE"),
                io.Mask.Output("MASK"),
            ],
        )

    @classmethod
    def execute(cls, mask, width, height, bg_color="#000000", image=None) -> io.NodeOutput:
        """Build the output image and mask.

        The base is the first entry of ``image`` when given, otherwise a
        ``height`` x ``width`` canvas filled with ``bg_color``. When ``mask``
        names a painted file, it is loaded as RGBA, resized to the base size if
        needed, and alpha-composited over the base; its alpha channel becomes
        the output mask. Otherwise the base is returned with an all-zero mask.
        """
        if image is None:
            h, w = height, width
            base_image = torch.zeros((1, h, w, 3), dtype=torch.float32)
            for channel, value in enumerate(hex_to_rgb(bg_color)):
                base_image[0, :, :, channel] = value
        else:
            base_image = image[:1]
            h, w = base_image.shape[1], base_image.shape[2]

        # Nothing painted: hand back the base image with an empty mask.
        if not (mask and mask.strip()):
            empty_mask = torch.zeros((1, h, w), dtype=torch.float32)
            return io.NodeOutput(base_image, empty_mask, ui=UI.PreviewImage(base_image))

        layer_path = folder_paths.get_annotated_filepath(mask)
        layer = node_helpers.pillow(Image.open, layer_path).convert("RGBA")
        if layer.size != (w, h):
            # PIL sizes are (width, height).
            layer = layer.resize((w, h), Image.LANCZOS)

        rgba = np.array(layer).astype(np.float32) / 255.0
        layer_rgb = rgba[:, :, :3]
        layer_alpha = rgba[:, :, 3:4]  # keeps the channel axis for broadcasting
        mask_tensor = torch.from_numpy(rgba[:, :, 3]).unsqueeze(0)

        background = base_image[0].cpu().numpy()
        composited = layer_rgb * layer_alpha + background * (1.0 - layer_alpha)
        out_image = torch.from_numpy(composited).unsqueeze(0)

        return io.NodeOutput(out_image, mask_tensor, ui=UI.PreviewImage(out_image))

    @classmethod
    def fingerprint_inputs(cls, mask, width, height, bg_color="#000000", image=None):
        """Return the SHA-256 hex digest of the painted file, or ``""`` if there is none."""
        if not (mask and mask.strip()):
            return ""
        mask_path = folder_paths.get_annotated_filepath(mask)
        if not os.path.exists(mask_path):
            return ""
        digest = hashlib.sha256()
        with open(mask_path, "rb") as fh:
            digest.update(fh.read())
        return digest.hexdigest()
class PainterExtension(ComfyExtension):
    """Extension that exposes the Painter node."""

    @override
    async def get_node_list(self):
        """Return the node classes provided by this extension."""
        node_classes = [PainterNode]
        return node_classes
async def comfy_entrypoint():
    """Create and return the extension object for this module."""
    extension = PainterExtension()
    return extension

View File

@ -21,6 +21,7 @@ class Blend(io.ComfyNode):
node_id="ImageBlend", node_id="ImageBlend",
display_name="Image Blend", display_name="Image Blend",
category="image/postprocessing", category="image/postprocessing",
essentials_category="Image Tools",
inputs=[ inputs=[
io.Image.Input("image1"), io.Image.Input("image1"),
io.Image.Input("image2"), io.Image.Input("image2"),

View File

@ -15,6 +15,7 @@ import comfy.sampler_helpers
import comfy.sd import comfy.sd
import comfy.utils import comfy.utils
import comfy.model_management import comfy.model_management
from comfy.cli_args import args, PerformanceFeature
import comfy_extras.nodes_custom_sampler import comfy_extras.nodes_custom_sampler
import folder_paths import folder_paths
import node_helpers import node_helpers
@ -138,6 +139,7 @@ class TrainSampler(comfy.samplers.Sampler):
training_dtype=torch.bfloat16, training_dtype=torch.bfloat16,
real_dataset=None, real_dataset=None,
bucket_latents=None, bucket_latents=None,
use_grad_scaler=False,
): ):
self.loss_fn = loss_fn self.loss_fn = loss_fn
self.optimizer = optimizer self.optimizer = optimizer
@ -152,6 +154,8 @@ class TrainSampler(comfy.samplers.Sampler):
self.bucket_latents: list[torch.Tensor] | None = ( self.bucket_latents: list[torch.Tensor] | None = (
bucket_latents # list of (Bi, C, Hi, Wi) bucket_latents # list of (Bi, C, Hi, Wi)
) )
# GradScaler for fp16 training
self.grad_scaler = torch.amp.GradScaler() if use_grad_scaler else None
# Precompute bucket offsets and weights for sampling # Precompute bucket offsets and weights for sampling
if bucket_latents is not None: if bucket_latents is not None:
self._init_bucket_data(bucket_latents) self._init_bucket_data(bucket_latents)
@ -204,10 +208,13 @@ class TrainSampler(comfy.samplers.Sampler):
batch_sigmas.requires_grad_(True), batch_sigmas.requires_grad_(True),
**batch_extra_args, **batch_extra_args,
) )
loss = self.loss_fn(x0_pred, x0) loss = self.loss_fn(x0_pred.float(), x0.float())
if bwd: if bwd:
bwd_loss = loss / self.grad_acc bwd_loss = loss / self.grad_acc
bwd_loss.backward() if self.grad_scaler is not None:
self.grad_scaler.scale(bwd_loss).backward()
else:
bwd_loss.backward()
return loss return loss
def _generate_batch_sigmas(self, model_wrap, batch_size, device): def _generate_batch_sigmas(self, model_wrap, batch_size, device):
@ -307,7 +314,10 @@ class TrainSampler(comfy.samplers.Sampler):
) )
total_loss += loss total_loss += loss
total_loss = total_loss / self.grad_acc / len(indicies) total_loss = total_loss / self.grad_acc / len(indicies)
total_loss.backward() if self.grad_scaler is not None:
self.grad_scaler.scale(total_loss).backward()
else:
total_loss.backward()
if self.loss_callback: if self.loss_callback:
self.loss_callback(total_loss.item()) self.loss_callback(total_loss.item())
pbar.set_postfix({"loss": f"{total_loss.item():.4f}"}) pbar.set_postfix({"loss": f"{total_loss.item():.4f}"})
@ -348,12 +358,18 @@ class TrainSampler(comfy.samplers.Sampler):
self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar) self._train_step_multires_mode(model_wrap, cond, extra_args, noisegen, latent_image, dataset_size, pbar)
if (i + 1) % self.grad_acc == 0: if (i + 1) % self.grad_acc == 0:
if self.grad_scaler is not None:
self.grad_scaler.unscale_(self.optimizer)
for param_groups in self.optimizer.param_groups: for param_groups in self.optimizer.param_groups:
for param in param_groups["params"]: for param in param_groups["params"]:
if param.grad is None: if param.grad is None:
continue continue
param.grad.data = param.grad.data.to(param.data.dtype) param.grad.data = param.grad.data.to(param.data.dtype)
self.optimizer.step() if self.grad_scaler is not None:
self.grad_scaler.step(self.optimizer)
self.grad_scaler.update()
else:
self.optimizer.step()
self.optimizer.zero_grad() self.optimizer.zero_grad()
ui_pbar.update(1) ui_pbar.update(1)
torch.cuda.empty_cache() torch.cuda.empty_cache()
@ -1004,9 +1020,9 @@ class TrainLoraNode(io.ComfyNode):
), ),
io.Combo.Input( io.Combo.Input(
"training_dtype", "training_dtype",
options=["bf16", "fp32"], options=["bf16", "fp32", "none"],
default="bf16", default="bf16",
tooltip="The dtype to use for training.", tooltip="The dtype to use for training. 'none' preserves the model's native compute dtype instead of overriding it. For fp16 models, GradScaler is automatically enabled.",
), ),
io.Combo.Input( io.Combo.Input(
"lora_dtype", "lora_dtype",
@ -1035,7 +1051,7 @@ class TrainLoraNode(io.ComfyNode):
io.Boolean.Input( io.Boolean.Input(
"offloading", "offloading",
default=False, default=False,
tooltip="Offload the Model to RAM. Requires Bypass Mode.", tooltip="Offload model weights to CPU during training to save GPU memory.",
), ),
io.Combo.Input( io.Combo.Input(
"existing_lora", "existing_lora",
@ -1120,22 +1136,32 @@ class TrainLoraNode(io.ComfyNode):
# Setup model and dtype # Setup model and dtype
mp = model.clone() mp = model.clone()
dtype = node_helpers.string_to_torch_dtype(training_dtype) use_grad_scaler = False
if training_dtype != "none":
dtype = node_helpers.string_to_torch_dtype(training_dtype)
mp.set_model_compute_dtype(dtype)
else:
# Detect model's native dtype for autocast
model_dtype = mp.model.get_dtype()
if model_dtype == torch.float16:
dtype = torch.float16
use_grad_scaler = True
# Warn about fp16 accumulation instability during training
if PerformanceFeature.Fp16Accumulation in args.fast:
logging.warning(
"WARNING: FP16 model detected with fp16_accumulation enabled. "
"This combination can be numerically unstable during training and may cause NaN values. "
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
)
else:
# For fp8, bf16, or other dtypes, use bf16 autocast
dtype = torch.bfloat16
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype) lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
mp.set_model_compute_dtype(dtype)
if mp.is_dynamic():
if not bypass_mode:
logging.info("Training MP is Dynamic - forcing bypass mode. Start comfy with --highvram to force weight diff mode")
bypass_mode = True
offloading = True
elif offloading:
if not bypass_mode:
logging.info("Training Offload selected - forcing bypass mode. Set bypass = True to remove this message")
# Prepare latents and compute counts # Prepare latents and compute counts
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
latents, num_images, multi_res = _prepare_latents_and_count( latents, num_images, multi_res = _prepare_latents_and_count(
latents, dtype, bucket_mode latents, latents_dtype, bucket_mode
) )
# Validate and expand conditioning # Validate and expand conditioning
@ -1201,6 +1227,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed, seed=seed,
training_dtype=dtype, training_dtype=dtype,
bucket_latents=latents, bucket_latents=latents,
use_grad_scaler=use_grad_scaler,
) )
else: else:
train_sampler = TrainSampler( train_sampler = TrainSampler(
@ -1213,6 +1240,7 @@ class TrainLoraNode(io.ComfyNode):
seed=seed, seed=seed,
training_dtype=dtype, training_dtype=dtype,
real_dataset=latents if multi_res else None, real_dataset=latents if multi_res else None,
use_grad_scaler=use_grad_scaler,
) )
# Setup guider # Setup guider
@ -1337,7 +1365,7 @@ class SaveLoRA(io.ComfyNode):
io.Int.Input( io.Int.Input(
"steps", "steps",
optional=True, optional=True,
tooltip="Optional: The number of steps to LoRA has been trained for, used to name the saved file.", tooltip="Optional: The number of steps the LoRA has been trained for, used to name the saved file.",
), ),
], ],
outputs=[], outputs=[],

View File

@ -1,3 +1,3 @@
# This file is automatically generated by the build process when version is # This file is automatically generated by the build process when version is
# updated in pyproject.toml. # updated in pyproject.toml.
__version__ = "0.16.4" __version__ = "0.18.1"

View File

@ -40,6 +40,7 @@ from comfy_execution.progress import get_progress_state, reset_progress_state, a
from comfy_execution.utils import CurrentNodeContext from comfy_execution.utils import CurrentNodeContext
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
from comfy_api.latest import io, _io from comfy_api.latest import io, _io
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
class ExecutionResult(Enum): class ExecutionResult(Enum):
@ -126,15 +127,15 @@ class CacheSet:
# Performs like the old cache -- dump data ASAP # Performs like the old cache -- dump data ASAP
def init_classic_cache(self): def init_classic_cache(self):
self.outputs = HierarchicalCache(CacheKeySetInputSignature) self.outputs = HierarchicalCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_lru_cache(self, cache_size): def init_lru_cache(self, cache_size):
self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size) self.outputs = LRUCache(CacheKeySetInputSignature, max_size=cache_size, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_ram_cache(self, min_headroom): def init_ram_cache(self, min_headroom):
self.outputs = RAMPressureCache(CacheKeySetInputSignature) self.outputs = RAMPressureCache(CacheKeySetInputSignature, enable_providers=True)
self.objects = HierarchicalCache(CacheKeySetID) self.objects = HierarchicalCache(CacheKeySetID)
def init_null_cache(self): def init_null_cache(self):
@ -418,7 +419,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
inputs = dynprompt.get_node(unique_id)['inputs'] inputs = dynprompt.get_node(unique_id)['inputs']
class_type = dynprompt.get_node(unique_id)['class_type'] class_type = dynprompt.get_node(unique_id)['class_type']
class_def = nodes.NODE_CLASS_MAPPINGS[class_type] class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
cached = caches.outputs.get(unique_id) cached = await caches.outputs.get(unique_id)
if cached is not None: if cached is not None:
if server.client_id is not None: if server.client_id is not None:
cached_ui = cached.ui or {} cached_ui = cached.ui or {}
@ -474,10 +475,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
server.last_node_id = display_node_id server.last_node_id = display_node_id
server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id) server.send_sync("executing", { "node": unique_id, "display_node": display_node_id, "prompt_id": prompt_id }, server.client_id)
obj = caches.objects.get(unique_id) obj = await caches.objects.get(unique_id)
if obj is None: if obj is None:
obj = class_def() obj = class_def()
caches.objects.set(unique_id, obj) await caches.objects.set(unique_id, obj)
if issubclass(class_def, _ComfyNodeInternal): if issubclass(class_def, _ComfyNodeInternal):
lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None lazy_status_present = first_real_override(class_def, "check_lazy_status") is not None
@ -588,7 +589,7 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data) cache_entry = CacheEntry(ui=ui_outputs.get(unique_id), outputs=output_data)
execution_list.cache_update(unique_id, cache_entry) execution_list.cache_update(unique_id, cache_entry)
caches.outputs.set(unique_id, cache_entry) await caches.outputs.set(unique_id, cache_entry)
except comfy.model_management.InterruptProcessingException as iex: except comfy.model_management.InterruptProcessingException as iex:
logging.info("Processing interrupted") logging.info("Processing interrupted")
@ -684,6 +685,19 @@ class PromptExecutor:
} }
self.add_message("execution_error", mes, broadcast=False) self.add_message("execution_error", mes, broadcast=False)
def _notify_prompt_lifecycle(self, event: str, prompt_id: str):
    """Tell every cache provider that a prompt started or ended.

    ``event`` is ``"start"`` or ``"end"``; any other value calls no hook.
    Provider failures are logged as warnings and never propagate, so one
    misbehaving provider cannot stop the others from being notified.
    """
    if not _has_cache_providers():
        return

    if event == "start":
        hook_name = "on_prompt_start"
    elif event == "end":
        hook_name = "on_prompt_end"
    else:
        hook_name = None

    for provider in _get_cache_providers():
        try:
            if hook_name is not None:
                # The lookup stays inside the try so a provider lacking the
                # hook is reported like any other provider error.
                getattr(provider, hook_name)(prompt_id)
        except Exception as e:
            _cache_logger.warning(f"Cache provider {provider.__class__.__name__} error on {event}: {e}")
def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]): def execute(self, prompt, prompt_id, extra_data={}, execute_outputs=[]):
asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs)) asyncio.run(self.execute_async(prompt, prompt_id, extra_data, execute_outputs))
@ -700,66 +714,75 @@ class PromptExecutor:
self.status_messages = [] self.status_messages = []
self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False) self.add_message("execution_start", { "prompt_id": prompt_id}, broadcast=False)
with torch.inference_mode(): self._notify_prompt_lifecycle("start", prompt_id)
dynamic_prompt = DynamicPrompt(prompt)
reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
cached_nodes = [] try:
for node_id in prompt: with torch.inference_mode():
if self.caches.outputs.get(node_id) is not None: dynamic_prompt = DynamicPrompt(prompt)
cached_nodes.append(node_id) reset_progress_state(prompt_id, dynamic_prompt)
add_progress_handler(WebUIProgressHandler(self.server))
is_changed_cache = IsChangedCache(prompt_id, dynamic_prompt, self.caches.outputs)
for cache in self.caches.all:
await cache.set_prompt(dynamic_prompt, prompt.keys(), is_changed_cache)
cache.clean_unused()
comfy.model_management.cleanup_models_gc() node_ids = list(prompt.keys())
self.add_message("execution_cached", cache_results = await asyncio.gather(
{ "nodes": cached_nodes, "prompt_id": prompt_id}, *(self.caches.outputs.get(node_id) for node_id in node_ids)
broadcast=False) )
pending_subgraph_results = {} cached_nodes = [
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results node_id for node_id, result in zip(node_ids, cache_results)
ui_node_outputs = {} if result is not None
executed = set() ]
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
while not execution_list.is_empty(): comfy.model_management.cleanup_models_gc()
node_id, error, ex = await execution_list.stage_node_execution() self.add_message("execution_cached",
if error is not None: { "nodes": cached_nodes, "prompt_id": prompt_id},
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) broadcast=False)
break pending_subgraph_results = {}
pending_async_nodes = {} # TODO - Unify this with pending_subgraph_results
ui_node_outputs = {}
executed = set()
execution_list = ExecutionList(dynamic_prompt, self.caches.outputs)
current_outputs = self.caches.outputs.all_node_ids()
for node_id in list(execute_outputs):
execution_list.add_node(node_id)
assert node_id is not None, "Node ID should not be None at this point" while not execution_list.is_empty():
result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs) node_id, error, ex = await execution_list.stage_node_execution()
self.success = result != ExecutionResult.FAILURE if error is not None:
if result == ExecutionResult.FAILURE: self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex) break
break
elif result == ExecutionResult.PENDING:
execution_list.unstage_node_execution()
else: # result == ExecutionResult.SUCCESS:
execution_list.complete_node_execution()
self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {} assert node_id is not None, "Node ID should not be None at this point"
meta_outputs = {} result, error, ex = await execute(self.server, dynamic_prompt, self.caches, node_id, extra_data, executed, prompt_id, execution_list, pending_subgraph_results, pending_async_nodes, ui_node_outputs)
for node_id, ui_info in ui_node_outputs.items(): self.success = result != ExecutionResult.FAILURE
ui_outputs[node_id] = ui_info["output"] if result == ExecutionResult.FAILURE:
meta_outputs[node_id] = ui_info["meta"] self.handle_execution_error(prompt_id, dynamic_prompt.original_prompt, current_outputs, executed, error, ex)
self.history_result = { break
"outputs": ui_outputs, elif result == ExecutionResult.PENDING:
"meta": meta_outputs, execution_list.unstage_node_execution()
} else: # result == ExecutionResult.SUCCESS:
self.server.last_node_id = None execution_list.complete_node_execution()
if comfy.model_management.DISABLE_SMART_MEMORY: self.caches.outputs.poll(ram_headroom=self.cache_args["ram"])
comfy.model_management.unload_all_models() else:
# Only execute when the while-loop ends without break
self.add_message("execution_success", { "prompt_id": prompt_id }, broadcast=False)
ui_outputs = {}
meta_outputs = {}
for node_id, ui_info in ui_node_outputs.items():
ui_outputs[node_id] = ui_info["output"]
meta_outputs[node_id] = ui_info["meta"]
self.history_result = {
"outputs": ui_outputs,
"meta": meta_outputs,
}
self.server.last_node_id = None
if comfy.model_management.DISABLE_SMART_MEMORY:
comfy.model_management.unload_all_models()
finally:
self._notify_prompt_lifecycle("end", prompt_id)
async def validate_inputs(prompt_id, prompt, item, validated): async def validate_inputs(prompt_id, prompt, item, validated):

View File

@ -206,8 +206,8 @@ import hook_breaker_ac10a0
import comfy.memory_management import comfy.memory_management
import comfy.model_patcher import comfy.model_patcher
if enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl(): if args.enable_dynamic_vram or (enables_dynamic_vram() and comfy.model_management.is_nvidia() and not comfy.model_management.is_wsl()):
if comfy.model_management.torch_version_numeric < (2, 8): if (not args.enable_dynamic_vram) and (comfy.model_management.torch_version_numeric < (2, 8)):
logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows") logging.warning("Unsupported Pytorch detected. DynamicVRAM support requires Pytorch version 2.8 or later. Falling back to legacy ModelPatcher. VRAM estimates may be unreliable especially on Windows")
elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index): elif comfy_aimdo.control.init_device(comfy.model_management.get_torch_device().index):
if args.verbose == 'DEBUG': if args.verbose == 'DEBUG':
@ -471,6 +471,9 @@ if __name__ == "__main__":
if sys.version_info.major == 3 and sys.version_info.minor < 10: if sys.version_info.major == 3 and sys.version_info.minor < 10:
logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.") logging.warning("WARNING: You are using a python version older than 3.10, please upgrade to a newer one. 3.12 and above is recommended.")
if args.disable_dynamic_vram:
logging.warning("Dynamic vram disabled with argument. If you have any issues with dynamic vram enabled please give us a detailed reports as this argument will be removed soon.")
event_loop, _, start_all_func = start_comfyui() event_loop, _, start_all_func = start_comfyui()
try: try:
x = start_all_func() x = start_all_func()

View File

@ -1 +1 @@
comfyui_manager==4.1b2 comfyui_manager==4.1b8

View File

@ -32,7 +32,7 @@ async def cache_control(
) )
if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point: if request.path.endswith(".js") or request.path.endswith(".css") or is_entry_point:
response.headers.setdefault("Cache-Control", "no-cache") response.headers.setdefault("Cache-Control", "no-store")
return response return response
# Early return for non-image files - no cache headers needed # Early return for non-image files - no cache headers needed

View File

@ -81,6 +81,7 @@ class CLIPTextEncode(ComfyNodeABC):
class ConditioningCombine: class ConditioningCombine:
ESSENTIALS_CATEGORY = "Image Generation"
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}} return {"required": {"conditioning_1": ("CONDITIONING", ), "conditioning_2": ("CONDITIONING", )}}
@ -951,7 +952,7 @@ class UNETLoader:
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ), return {"required": { "unet_name": (folder_paths.get_filename_list("diffusion_models"), ),
"weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"],) "weight_dtype": (["default", "fp8_e4m3fn", "fp8_e4m3fn_fast", "fp8_e5m2"], {"advanced": True})
}} }}
RETURN_TYPES = ("MODEL",) RETURN_TYPES = ("MODEL",)
FUNCTION = "load_unet" FUNCTION = "load_unet"
@ -1211,9 +1212,6 @@ class GLIGENTextBoxApply:
return (c, ) return (c, )
class EmptyLatentImage: class EmptyLatentImage:
def __init__(self):
self.device = comfy.model_management.intermediate_device()
@classmethod @classmethod
def INPUT_TYPES(s): def INPUT_TYPES(s):
return { return {
@ -1232,7 +1230,7 @@ class EmptyLatentImage:
SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"] SEARCH_ALIASES = ["empty", "empty latent", "new latent", "create latent", "blank latent", "blank"]
def generate(self, width, height, batch_size=1): def generate(self, width, height, batch_size=1):
latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=self.device) latent = torch.zeros([batch_size, 4, height // 8, width // 8], device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
return ({"samples": latent, "downscale_ratio_spacial": 8}, ) return ({"samples": latent, "downscale_ratio_spacial": 8}, )
@ -1724,6 +1722,8 @@ class LoadImage:
output_masks = [] output_masks = []
w, h = None, None w, h = None, None
dtype = comfy.model_management.intermediate_dtype()
for i in ImageSequence.Iterator(img): for i in ImageSequence.Iterator(img):
i = node_helpers.pillow(ImageOps.exif_transpose, i) i = node_helpers.pillow(ImageOps.exif_transpose, i)
@ -1748,8 +1748,8 @@ class LoadImage:
mask = 1. - torch.from_numpy(mask) mask = 1. - torch.from_numpy(mask)
else: else:
mask = torch.zeros((64,64), dtype=torch.float32, device="cpu") mask = torch.zeros((64,64), dtype=torch.float32, device="cpu")
output_images.append(image) output_images.append(image.to(dtype=dtype))
output_masks.append(mask.unsqueeze(0)) output_masks.append(mask.unsqueeze(0).to(dtype=dtype))
if img.format == "MPO": if img.format == "MPO":
break # ignore all frames except the first one for MPO format break # ignore all frames except the first one for MPO format
@ -1779,6 +1779,7 @@ class LoadImage:
return True return True
class LoadImageMask: class LoadImageMask:
ESSENTIALS_CATEGORY = "Image Tools"
SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"] SEARCH_ALIASES = ["import mask", "alpha mask", "channel mask"]
_color_channels = ["alpha", "red", "green", "blue"] _color_channels = ["alpha", "red", "green", "blue"]
@ -1887,6 +1888,7 @@ class ImageScale:
return (s,) return (s,)
class ImageScaleBy: class ImageScaleBy:
ESSENTIALS_CATEGORY = "Image Tools"
upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"] upscale_methods = ["nearest-exact", "bilinear", "area", "bicubic", "lanczos"]
@classmethod @classmethod
@ -1964,9 +1966,11 @@ class EmptyImage:
CATEGORY = "image" CATEGORY = "image"
def generate(self, width, height, batch_size=1, color=0): def generate(self, width, height, batch_size=1, color=0):
r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF) dtype = comfy.model_management.intermediate_dtype()
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF) device = comfy.model_management.intermediate_device()
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF) r = torch.full([batch_size, height, width, 1], ((color >> 16) & 0xFF) / 0xFF, device=device, dtype=dtype)
g = torch.full([batch_size, height, width, 1], ((color >> 8) & 0xFF) / 0xFF, device=device, dtype=dtype)
b = torch.full([batch_size, height, width, 1], ((color) & 0xFF) / 0xFF, device=device, dtype=dtype)
return (torch.cat((r, g, b), dim=-1), ) return (torch.cat((r, g, b), dim=-1), )
class ImagePadForOutpaint: class ImagePadForOutpaint:
@ -2450,6 +2454,7 @@ async def init_builtin_extra_nodes():
"nodes_nag.py", "nodes_nag.py",
"nodes_sdpose.py", "nodes_sdpose.py",
"nodes_math.py", "nodes_math.py",
"nodes_painter.py",
] ]
import_failed = [] import_failed = []

View File

@ -1,6 +1,6 @@
[project] [project]
name = "ComfyUI" name = "ComfyUI"
version = "0.16.4" version = "0.18.1"
readme = "README.md" readme = "README.md"
license = { file = "LICENSE" } license = { file = "LICENSE" }
requires-python = ">=3.10" requires-python = ">=3.10"

View File

@ -1,5 +1,5 @@
comfyui-frontend-package==1.39.19 comfyui-frontend-package==1.42.8
comfyui-workflow-templates==0.9.18 comfyui-workflow-templates==0.9.26
comfyui-embedded-docs==0.4.3 comfyui-embedded-docs==0.4.3
torch torch
torchsde torchsde
@ -22,8 +22,8 @@ alembic
SQLAlchemy SQLAlchemy
filelock filelock
av>=14.2.0 av>=14.2.0
comfy-kitchen>=0.2.7 comfy-kitchen>=0.2.8
comfy-aimdo>=0.2.9 comfy-aimdo>=0.2.12
requests requests
simpleeval>=1.0.0 simpleeval>=1.0.0
blake3 blake3

View File

@ -35,6 +35,8 @@ from app.frontend_management import FrontendManager, parse_version
from comfy_api.internal import _ComfyNodeInternal from comfy_api.internal import _ComfyNodeInternal
from app.assets.seeder import asset_seeder from app.assets.seeder import asset_seeder
from app.assets.api.routes import register_assets_routes from app.assets.api.routes import register_assets_routes
from app.assets.services.ingest import register_file_in_place
from app.assets.services.asset_management import resolve_hash_to_path
from app.user_manager import UserManager from app.user_manager import UserManager
from app.model_manager import ModelFileManager from app.model_manager import ModelFileManager
@ -310,7 +312,7 @@ class PromptServer():
@routes.get("/") @routes.get("/")
async def get_root(request): async def get_root(request):
response = web.FileResponse(os.path.join(self.web_root, "index.html")) response = web.FileResponse(os.path.join(self.web_root, "index.html"))
response.headers['Cache-Control'] = 'no-cache' response.headers['Cache-Control'] = 'no-store, must-revalidate'
response.headers["Pragma"] = "no-cache" response.headers["Pragma"] = "no-cache"
response.headers["Expires"] = "0" response.headers["Expires"] = "0"
return response return response
@ -419,7 +421,24 @@ class PromptServer():
with open(filepath, "wb") as f: with open(filepath, "wb") as f:
f.write(image.file.read()) f.write(image.file.read())
return web.json_response({"name" : filename, "subfolder": subfolder, "type": image_upload_type}) resp = {"name" : filename, "subfolder": subfolder, "type": image_upload_type}
if args.enable_assets:
try:
tag = image_upload_type if image_upload_type in ("input", "output") else "input"
result = register_file_in_place(abs_path=filepath, name=filename, tags=[tag])
resp["asset"] = {
"id": result.ref.id,
"name": result.ref.name,
"asset_hash": result.asset.hash,
"size": result.asset.size_bytes,
"mime_type": result.asset.mime_type,
"tags": result.tags,
}
except Exception:
logging.warning("Failed to register uploaded image as asset", exc_info=True)
return web.json_response(resp)
else: else:
return web.Response(status=400) return web.Response(status=400)
@ -479,30 +498,43 @@ class PromptServer():
async def view_image(request): async def view_image(request):
if "filename" in request.rel_url.query: if "filename" in request.rel_url.query:
filename = request.rel_url.query["filename"] filename = request.rel_url.query["filename"]
filename, output_dir = folder_paths.annotated_filepath(filename)
if not filename: # The frontend's LoadImage combo widget uses asset_hash values
return web.Response(status=400) # (e.g. "blake3:...") as widget values. When litegraph renders the
# node preview, it constructs /view?filename=<asset_hash>, so this
# endpoint must resolve blake3 hashes to their on-disk file paths.
if filename.startswith("blake3:"):
owner_id = self.user_manager.get_request_user_id(request)
result = resolve_hash_to_path(filename, owner_id=owner_id)
if result is None:
return web.Response(status=404)
file, filename, resolved_content_type = result.abs_path, result.download_name, result.content_type
else:
resolved_content_type = None
filename, output_dir = folder_paths.annotated_filepath(filename)
# validation for security: prevent accessing arbitrary path if not filename:
if filename[0] == '/' or '..' in filename: return web.Response(status=400)
return web.Response(status=400)
if output_dir is None: # validation for security: prevent accessing arbitrary path
type = request.rel_url.query.get("type", "output") if filename[0] == '/' or '..' in filename:
output_dir = folder_paths.get_directory_by_type(type) return web.Response(status=400)
if output_dir is None: if output_dir is None:
return web.Response(status=400) type = request.rel_url.query.get("type", "output")
output_dir = folder_paths.get_directory_by_type(type)
if "subfolder" in request.rel_url.query: if output_dir is None:
full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"]) return web.Response(status=400)
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename) if "subfolder" in request.rel_url.query:
file = os.path.join(output_dir, filename) full_output_dir = os.path.join(output_dir, request.rel_url.query["subfolder"])
if os.path.commonpath((os.path.abspath(full_output_dir), output_dir)) != output_dir:
return web.Response(status=403)
output_dir = full_output_dir
filename = os.path.basename(filename)
file = os.path.join(output_dir, filename)
if os.path.isfile(file): if os.path.isfile(file):
if 'preview' in request.rel_url.query: if 'preview' in request.rel_url.query:
@ -562,8 +594,13 @@ class PromptServer():
return web.Response(body=alpha_buffer.read(), content_type='image/png', return web.Response(body=alpha_buffer.read(), content_type='image/png',
headers={"Content-Disposition": f"filename=\"{filename}\""}) headers={"Content-Disposition": f"filename=\"{filename}\""})
else: else:
# Get content type from mimetype, defaulting to 'application/octet-stream' # Use the content type from asset resolution if available,
content_type = mimetypes.guess_type(filename)[0] or 'application/octet-stream' # otherwise guess from the filename.
content_type = (
resolved_content_type
or mimetypes.guess_type(filename)[0]
or 'application/octet-stream'
)
# For security, force certain mimetypes to download instead of display # For security, force certain mimetypes to download instead of display
if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}: if content_type in {'text/html', 'text/html-sandboxed', 'application/xhtml+xml', 'text/javascript', 'text/css'}:

View File

@ -0,0 +1,57 @@
"""Test that Alembic migrations run cleanly on a file-backed SQLite DB.
This catches problems like unnamed FK constraints that prevent batch-mode
drop_constraint from working on real SQLite files (see MB-2).
Migrations 0001 and 0002 are already shipped, so we only exercise
upgrade/downgrade for 0003+.
"""
import os
import pytest
from alembic import command
from alembic.config import Config
# Oldest shipped revision — we upgrade to here as a baseline and never
# downgrade past it.
_BASELINE = "0002_merge_to_asset_references"
def _make_config(db_path: str) -> Config:
    """Build an Alembic ``Config`` that targets the SQLite file at *db_path*.

    ``alembic.ini`` and the ``alembic_db`` script directory are resolved
    two directory levels above this test module.
    """
    repo_root = os.path.join(os.path.dirname(__file__), "../..")
    ini_file = os.path.abspath(os.path.join(repo_root, "alembic.ini"))
    script_dir = os.path.abspath(os.path.join(repo_root, "alembic_db"))

    alembic_cfg = Config(ini_file)
    overrides = (
        ("script_location", script_dir),
        ("sqlalchemy.url", f"sqlite:///{db_path}"),
    )
    for option, value in overrides:
        alembic_cfg.set_main_option(option, value)
    return alembic_cfg
@pytest.fixture
def migration_db(tmp_path):
    """Yield an Alembic config whose database is already at ``_BASELINE``.

    The database is a fresh SQLite file created under pytest's ``tmp_path``.
    """
    alembic_cfg = _make_config(str(tmp_path / "test_migration.db"))
    command.upgrade(alembic_cfg, _BASELINE)
    yield alembic_cfg
def test_upgrade_to_head(migration_db):
    """Migrating the file-backed DB from the baseline to head must not raise."""
    target_revision = "head"
    command.upgrade(migration_db, target_revision)
def test_downgrade_to_baseline(migration_db):
    """Go up to head, then roll back down to the baseline revision."""
    alembic_cfg = migration_db
    command.upgrade(alembic_cfg, "head")
    command.downgrade(alembic_cfg, _BASELINE)
def test_upgrade_downgrade_cycle(migration_db):
    """Round trip: upgrade to head, downgrade to baseline, upgrade to head again."""
    # Each step is (alembic command, target revision), executed in order.
    steps = (
        (command.upgrade, "head"),
        (command.downgrade, _BASELINE),
        (command.upgrade, "head"),
    )
    for run_step, revision in steps:
        run_step(migration_db, revision)

View File

@ -10,6 +10,7 @@ from app.assets.database.queries import (
get_asset_by_hash, get_asset_by_hash,
upsert_asset, upsert_asset,
bulk_insert_assets, bulk_insert_assets,
update_asset_hash_and_mime,
) )
@ -142,3 +143,45 @@ class TestBulkInsertAssets:
session.commit() session.commit()
assert session.query(Asset).count() == 200 assert session.query(Asset).count() == 200
class TestMimeTypeImmutability:
    """mime_type on Asset is write-once: set on first ingest, never overwritten."""

    @pytest.mark.parametrize(
        ("initial_mime", "second_mime", "expected_mime"),
        [
            pytest.param("image/png", "image/jpeg", "image/png", id="preserves_existing"),
            pytest.param(None, "image/png", "image/png", id="fills_null"),
        ],
    )
    def test_upsert_mime_immutability(self, session: Session, initial_mime, second_mime, expected_mime):
        """A second upsert of the same hash keeps a set mime_type and fills a NULL one."""
        digest = f"blake3:upsert_{initial_mime}_{second_mime}"

        # First ingest establishes the row.
        upsert_asset(session, asset_hash=digest, size_bytes=100, mime_type=initial_mime)
        session.commit()

        # Second ingest hits the existing row with a different mime_type.
        asset, was_created, _ = upsert_asset(
            session, asset_hash=digest, size_bytes=100, mime_type=second_mime,
        )

        assert was_created is False
        assert asset.mime_type == expected_mime

    @pytest.mark.parametrize(
        ("initial_mime", "update_mime", "update_hash", "expected_mime", "expected_hash"),
        [
            pytest.param(None, "image/png", None, "image/png", "blake3:upd0", id="fills_null"),
            pytest.param(
                "image/png", "image/jpeg", None, "image/png", "blake3:upd1",
                id="preserves_existing",
            ),
            pytest.param(
                "image/png", "image/jpeg", "blake3:upd2_new", "image/png", "blake3:upd2_new",
                id="hash_updates_mime_locked",
            ),
        ],
    )
    def test_update_asset_hash_and_mime_immutability(
        self, session: Session, initial_mime, update_mime, update_hash, expected_mime, expected_hash,
    ):
        """update_asset_hash_and_mime may change the hash but leaves a set mime_type alone."""
        # The starting hash is the expected one minus the "_new" marker, so the
        # hash-changing case starts from a different value than it ends with.
        starting_hash = expected_hash.removesuffix("_new")
        row = Asset(hash=starting_hash, size_bytes=100, mime_type=initial_mime)
        session.add(row)
        session.flush()

        update_asset_hash_and_mime(
            session, asset_id=row.id, mime_type=update_mime, asset_hash=update_hash,
        )

        assert row.mime_type == expected_mime
        assert row.hash == expected_hash

View File

@ -242,22 +242,24 @@ class TestSetReferencePreview:
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit() session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=preview_asset.id) set_reference_preview(session, reference_id=ref.id, preview_reference_id=preview_ref.id)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
assert ref.preview_id == preview_asset.id assert ref.preview_id == preview_ref.id
def test_clears_preview(self, session: Session): def test_clears_preview(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
session.commit() session.commit()
set_reference_preview(session, reference_id=ref.id, preview_asset_id=None) set_reference_preview(session, reference_id=ref.id, preview_reference_id=None)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
@ -265,15 +267,15 @@ class TestSetReferencePreview:
def test_raises_for_nonexistent_reference(self, session: Session): def test_raises_for_nonexistent_reference(self, session: Session):
with pytest.raises(ValueError, match="not found"): with pytest.raises(ValueError, match="not found"):
set_reference_preview(session, reference_id="nonexistent", preview_asset_id=None) set_reference_preview(session, reference_id="nonexistent", preview_reference_id=None)
def test_raises_for_nonexistent_preview(self, session: Session): def test_raises_for_nonexistent_preview(self, session: Session):
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
session.commit() session.commit()
with pytest.raises(ValueError, match="Preview Asset"): with pytest.raises(ValueError, match="Preview AssetReference"):
set_reference_preview(session, reference_id=ref.id, preview_asset_id="nonexistent") set_reference_preview(session, reference_id=ref.id, preview_reference_id="nonexistent")
class TestInsertReference: class TestInsertReference:
@ -351,13 +353,14 @@ class TestUpdateReferenceTimestamps:
asset = _make_asset(session, "hash1") asset = _make_asset(session, "hash1")
preview_asset = _make_asset(session, "preview_hash") preview_asset = _make_asset(session, "preview_hash")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
session.commit() session.commit()
update_reference_timestamps(session, ref, preview_id=preview_asset.id) update_reference_timestamps(session, ref, preview_id=preview_ref.id)
session.commit() session.commit()
session.refresh(ref) session.refresh(ref)
assert ref.preview_id == preview_asset.id assert ref.preview_id == preview_ref.id
class TestSetReferenceMetadata: class TestSetReferenceMetadata:

View File

@ -20,6 +20,7 @@ def _make_reference(
asset: Asset, asset: Asset,
name: str, name: str,
metadata: dict | None = None, metadata: dict | None = None,
system_metadata: dict | None = None,
) -> AssetReference: ) -> AssetReference:
now = get_utc_now() now = get_utc_now()
ref = AssetReference( ref = AssetReference(
@ -27,6 +28,7 @@ def _make_reference(
name=name, name=name,
asset_id=asset.id, asset_id=asset.id,
user_metadata=metadata, user_metadata=metadata,
system_metadata=system_metadata,
created_at=now, created_at=now,
updated_at=now, updated_at=now,
last_access_time=now, last_access_time=now,
@ -34,8 +36,10 @@ def _make_reference(
session.add(ref) session.add(ref)
session.flush() session.flush()
if metadata: # Build merged projection: {**system_metadata, **user_metadata}
for key, val in metadata.items(): merged = {**(system_metadata or {}), **(metadata or {})}
if merged:
for key, val in merged.items():
for row in convert_metadata_to_rows(key, val): for row in convert_metadata_to_rows(key, val):
meta_row = AssetReferenceMeta( meta_row = AssetReferenceMeta(
asset_reference_id=ref.id, asset_reference_id=ref.id,
@ -182,3 +186,46 @@ class TestMetadataFilterEmptyDict:
refs, _, total = list_references_page(session, metadata_filter={}) refs, _, total = list_references_page(session, metadata_filter={})
assert total == 2 assert total == 2
class TestSystemMetadataProjection:
    """Tests for system_metadata merging into the filter projection."""

    def test_system_metadata_keys_are_filterable(self, session: Session):
        """A key present only in system_metadata can be used in metadata_filter."""
        shared_asset = _make_asset(session, "hash1")
        _make_reference(
            session, shared_asset, "with_sys",
            system_metadata={"source": "scanner"},
        )
        _make_reference(session, shared_asset, "without_sys")
        session.commit()

        wanted = {"source": "scanner"}
        matches, _, match_count = list_references_page(session, metadata_filter=wanted)

        assert match_count == 1
        assert matches[0].name == "with_sys"

    def test_user_metadata_overrides_system_metadata(self, session: Session):
        """When both metadata dicts define a key, filtering sees the user value only."""
        shared_asset = _make_asset(session, "hash1")
        _make_reference(
            session, shared_asset, "overridden",
            metadata={"origin": "user_upload"},
            system_metadata={"origin": "auto_scan"},
        )
        session.commit()

        # Filtering on the user-supplied value finds the reference.
        by_user_value, _, user_hits = list_references_page(
            session, metadata_filter={"origin": "user_upload"}
        )
        assert user_hits == 1
        assert by_user_value[0].name == "overridden"

        # Filtering on the overridden system value finds nothing.
        _by_system_value, _, system_hits = list_references_page(
            session, metadata_filter={"origin": "auto_scan"}
        )
        assert system_hits == 0

View File

@ -11,6 +11,7 @@ from app.assets.services import (
delete_asset_reference, delete_asset_reference,
set_asset_preview, set_asset_preview,
) )
from app.assets.services.asset_management import resolve_hash_to_path
def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset: def _make_asset(session: Session, hash_val: str = "blake3:test", size: int = 1024) -> Asset:
@ -219,31 +220,33 @@ class TestSetAssetPreview:
asset = _make_asset(session, hash_val="blake3:main") asset = _make_asset(session, hash_val="blake3:main")
preview_asset = _make_asset(session, hash_val="blake3:preview") preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref_id = ref.id ref_id = ref.id
preview_id = preview_asset.id preview_ref_id = preview_ref.id
session.commit() session.commit()
set_asset_preview( set_asset_preview(
reference_id=ref_id, reference_id=ref_id,
preview_asset_id=preview_id, preview_reference_id=preview_ref_id,
) )
# Verify by re-fetching from DB # Verify by re-fetching from DB
session.expire_all() session.expire_all()
updated_ref = session.get(AssetReference, ref_id) updated_ref = session.get(AssetReference, ref_id)
assert updated_ref.preview_id == preview_id assert updated_ref.preview_id == preview_ref_id
def test_clears_preview(self, mock_create_session, session: Session): def test_clears_preview(self, mock_create_session, session: Session):
asset = _make_asset(session) asset = _make_asset(session)
preview_asset = _make_asset(session, hash_val="blake3:preview") preview_asset = _make_asset(session, hash_val="blake3:preview")
ref = _make_reference(session, asset) ref = _make_reference(session, asset)
ref.preview_id = preview_asset.id preview_ref = _make_reference(session, preview_asset, name="preview.png")
ref.preview_id = preview_ref.id
ref_id = ref.id ref_id = ref.id
session.commit() session.commit()
set_asset_preview( set_asset_preview(
reference_id=ref_id, reference_id=ref_id,
preview_asset_id=None, preview_reference_id=None,
) )
# Verify by re-fetching from DB # Verify by re-fetching from DB
@ -263,6 +266,45 @@ class TestSetAssetPreview:
with pytest.raises(PermissionError, match="not owner"): with pytest.raises(PermissionError, match="not owner"):
set_asset_preview( set_asset_preview(
reference_id=ref.id, reference_id=ref.id,
preview_asset_id=None, preview_reference_id=None,
owner_id="user2", owner_id="user2",
) )
class TestResolveHashToPath:
    """Tests for resolve_hash_to_path: unknown hashes and owner-based visibility."""

    def test_returns_none_for_unknown_hash(self, mock_create_session):
        """A hash with no matching asset resolves to None."""
        unknown_hash = "blake3:" + "a" * 64
        assert resolve_hash_to_path(unknown_hash) is None

    @pytest.mark.parametrize(
        ("ref_owner", "query_owner", "expect_found"),
        [
            pytest.param("user1", "user1", True, id="owner_sees_own_ref"),
            pytest.param("user1", "user2", False, id="other_owner_blocked"),
            pytest.param("", "anyone", True, id="ownerless_visible_to_anyone"),
            pytest.param("", "", True, id="ownerless_visible_to_empty"),
        ],
    )
    def test_owner_visibility(
        self, ref_owner, query_owner, expect_found,
        mock_create_session, session: Session, temp_dir,
    ):
        """A reference resolves for its owner, or for anyone when it has no owner."""
        # Back the reference with a real file on disk.
        on_disk = temp_dir / "file.bin"
        on_disk.write_bytes(b"data")

        asset = _make_asset(session, hash_val="blake3:" + "b" * 64)
        reference = _make_reference(session, asset, name="file.bin", owner_id=ref_owner)
        reference.file_path = str(on_disk)
        session.commit()

        resolved = resolve_hash_to_path(asset.hash, owner_id=query_owner)

        if not expect_found:
            assert resolved is None
            return
        assert resolved is not None
        assert resolved.abs_path == str(on_disk)

View File

@ -113,11 +113,19 @@ class TestIngestFileFromPath:
file_path = temp_dir / "with_preview.bin" file_path = temp_dir / "with_preview.bin"
file_path.write_bytes(b"data") file_path.write_bytes(b"data")
# Create a preview asset first # Create a preview asset and reference
preview_asset = Asset(hash="blake3:preview", size_bytes=100) preview_asset = Asset(hash="blake3:preview", size_bytes=100)
session.add(preview_asset) session.add(preview_asset)
session.flush()
from app.assets.helpers import get_utc_now
now = get_utc_now()
preview_ref = AssetReference(
asset_id=preview_asset.id, name="preview.png", owner_id="",
created_at=now, updated_at=now, last_access_time=now,
)
session.add(preview_ref)
session.commit() session.commit()
preview_id = preview_asset.id preview_id = preview_ref.id
result = _ingest_file_from_path( result = _ingest_file_from_path(
abs_path=str(file_path), abs_path=str(file_path),

View File

@ -0,0 +1,123 @@
"""Tests for list_tag_histogram service function."""
from sqlalchemy.orm import Session
from app.assets.database.models import Asset, AssetReference
from app.assets.database.queries import ensure_tags_exist, add_tags_to_reference
from app.assets.helpers import get_utc_now
from app.assets.services.tagging import list_tag_histogram
def _make_asset(session: Session, hash_val: str = "blake3:test") -> Asset:
    """Add a 1024-byte ``Asset`` with hash *hash_val* to *session*, flush, and return it."""
    new_asset = Asset(hash=hash_val, size_bytes=1024)
    session.add(new_asset)
    session.flush()
    return new_asset
def _make_reference(
    session: Session,
    asset: Asset,
    name: str = "test",
    owner_id: str = "",
) -> AssetReference:
    """Add an ``AssetReference`` pointing at *asset* to *session*, flush, and return it.

    The three timestamp columns all receive the same ``get_utc_now()`` value.
    """
    timestamp = get_utc_now()
    stamps = dict.fromkeys(("created_at", "updated_at", "last_access_time"), timestamp)
    reference = AssetReference(
        owner_id=owner_id,
        name=name,
        asset_id=asset.id,
        **stamps,
    )
    session.add(reference)
    session.flush()
    return reference
class TestListTagHistogram:
    """Exercise ``list_tag_histogram`` against references tagged in the test session."""

    @staticmethod
    def _tagged_reference(session: Session, hash_val: str, name: str, tags: list) -> AssetReference:
        """Create an asset, one reference to it named *name*, and attach *tags* to the reference."""
        asset = _make_asset(session, hash_val)
        reference = _make_reference(session, asset, name=name)
        add_tags_to_reference(session, reference_id=reference.id, tags=tags)
        return reference

    def test_returns_counts_for_all_tags(self, mock_create_session, session: Session):
        """With no filters, each tag maps to the number of references carrying it."""
        ensure_tags_exist(session, ["alpha", "beta"])
        self._tagged_reference(session, "blake3:aaa", "r1", ["alpha", "beta"])
        self._tagged_reference(session, "blake3:bbb", "r2", ["alpha"])
        session.commit()

        histogram = list_tag_histogram()

        assert histogram["alpha"] == 2
        assert histogram["beta"] == 1

    def test_empty_when_no_assets(self, mock_create_session, session: Session):
        """A tag that exists but is attached to nothing yields an empty histogram."""
        ensure_tags_exist(session, ["unused"])
        session.commit()

        assert list_tag_histogram() == {}

    def test_include_tags_filter(self, mock_create_session, session: Session):
        """include_tags keeps only references with that tag; all of their tags are counted."""
        ensure_tags_exist(session, ["models", "loras", "input"])
        self._tagged_reference(session, "blake3:aaa", "r1", ["models", "loras"])
        self._tagged_reference(session, "blake3:bbb", "r2", ["input"])
        session.commit()

        histogram = list_tag_histogram(include_tags=["models"])

        # Only r1 has "models", so only its tags appear
        assert "models" in histogram
        assert "loras" in histogram
        assert "input" not in histogram

    def test_exclude_tags_filter(self, mock_create_session, session: Session):
        """exclude_tags drops references carrying that tag, along with their other tags."""
        ensure_tags_exist(session, ["models", "loras", "input"])
        self._tagged_reference(session, "blake3:aaa", "r1", ["models", "loras"])
        self._tagged_reference(session, "blake3:bbb", "r2", ["input"])
        session.commit()

        histogram = list_tag_histogram(exclude_tags=["models"])

        # r1 excluded, only r2's tags remain
        assert "input" in histogram
        assert "loras" not in histogram

    def test_name_contains_filter(self, mock_create_session, session: Session):
        """name_contains restricts the histogram to references whose name matches."""
        ensure_tags_exist(session, ["alpha", "beta"])
        self._tagged_reference(session, "blake3:aaa", "my_model.safetensors", ["alpha"])
        self._tagged_reference(session, "blake3:bbb", "picture.png", ["beta"])
        session.commit()

        histogram = list_tag_histogram(name_contains="model")

        assert "alpha" in histogram
        assert "beta" not in histogram

    def test_limit_caps_results(self, mock_create_session, session: Session):
        """limit bounds the number of tags returned."""
        ten_tags = [f"tag{i}" for i in range(10)]
        ensure_tags_exist(session, ten_tags)
        self._tagged_reference(session, "blake3:aaa", "r1", ten_tags)
        session.commit()

        histogram = list_tag_histogram(limit=3)

        assert len(histogram) == 3

View File

@ -243,6 +243,15 @@ def test_upload_tags_traversal_guard(http: requests.Session, api_base: str):
assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY") assert body["error"]["code"] in ("BAD_REQUEST", "INVALID_BODY")
def test_upload_empty_tags_rejected(http: requests.Session, api_base: str):
    """Posting an upload whose ``tags`` field is an empty JSON list must give 400 / INVALID_BODY."""
    upload_name = "notags.bin"
    multipart_files = {"file": (upload_name, b"A" * 64, "application/octet-stream")}
    form_fields = {
        "tags": json.dumps([]),
        "name": upload_name,
        "user_metadata": json.dumps({}),
    }

    response = http.post(api_base + "/api/assets", data=form_fields, files=multipart_files, timeout=120)
    payload = response.json()

    assert response.status_code == 400
    assert payload["error"]["code"] == "INVALID_BODY"
@pytest.mark.parametrize("root", ["input", "output"]) @pytest.mark.parametrize("root", ["input", "output"])
def test_duplicate_upload_same_display_name_does_not_clobber( def test_duplicate_upload_same_display_name_does_not_clobber(
root: str, root: str,

View File

@ -0,0 +1,403 @@
"""Tests for external cache provider API."""
import importlib.util
import pytest
from typing import Optional
def _torch_available() -> bool:
    """Report whether a ``torch`` module spec can be found, without importing it."""
    torch_spec = importlib.util.find_spec("torch")
    return torch_spec is not None
from comfy_execution.cache_provider import (
CacheProvider,
CacheContext,
CacheValue,
register_cache_provider,
unregister_cache_provider,
_get_cache_providers,
_has_cache_providers,
_clear_cache_providers,
_serialize_cache_key,
_contains_self_unequal,
_estimate_value_size,
_canonicalize,
)
class TestCanonicalize:
    """Test _canonicalize function for deterministic ordering."""

    def test_frozenset_ordering_is_deterministic(self):
        """Equal frozensets built in different orders canonicalize identically."""
        in_order = frozenset([("a", 1), ("b", 2), ("c", 3)])
        shuffled = frozenset([("c", 3), ("a", 1), ("b", 2)])

        assert _canonicalize(in_order) == _canonicalize(shuffled)

    def test_nested_frozenset_ordering(self):
        """A frozenset nested inside another is ordered deterministically too."""
        ascending = frozenset([1, 2, 3])
        descending = frozenset([3, 2, 1])
        outer_a = frozenset([("key", ascending)])
        outer_b = frozenset([("key", descending)])

        assert _canonicalize(outer_a) == _canonicalize(outer_b)

    def test_dict_ordering(self):
        """Dicts with the same items in different insertion order canonicalize identically."""
        unsorted_keys = {"z": 1, "a": 2, "m": 3}
        sorted_keys = {"a": 2, "m": 3, "z": 1}

        assert _canonicalize(unsorted_keys) == _canonicalize(sorted_keys)

    def test_tuple_preserved(self):
        """The canonical form of a tuple starts with the ``__tuple__`` marker."""
        canonical = _canonicalize((1, 2, 3))
        assert canonical[0] == "__tuple__"

    def test_list_preserved(self):
        """List elements are canonicalized recursively."""
        canonical = _canonicalize([{"b": 2, "a": 1}, frozenset([3, 2, 1])])

        # First element should be canonicalized dict
        assert "__dict__" in canonical[0]
        # Second element should be canonicalized frozenset
        assert canonical[1][0] == "__frozenset__"

    def test_primitives_include_type(self):
        """Primitives canonicalize to a ``(type name, value)`` pair."""
        expectations = (
            (42, "int"),
            (3.14, "float"),
            ("hello", "str"),
            (True, "bool"),
            (None, "NoneType"),
        )
        for value, type_name in expectations:
            assert _canonicalize(value) == (type_name, value)

    def test_int_and_str_distinguished(self):
        """int 7 and str '7' must produce different canonical forms."""
        as_int = _canonicalize(7)
        as_str = _canonicalize("7")
        assert as_int != as_str

    def test_bytes_converted(self):
        """Bytes canonicalize to the ``__bytes__`` marker followed by their hex string."""
        canonical = _canonicalize(b"\x00\xff")
        assert canonical[0] == "__bytes__"
        assert canonical[1] == "00ff"

    def test_set_ordering(self):
        """Equal sets canonicalize identically and carry the ``__set__`` marker."""
        first = _canonicalize({3, 1, 2})
        second = _canonicalize({1, 2, 3})

        assert first == second
        assert first[0] == "__set__"

    def test_unknown_type_raises(self):
        """Unknown types should raise ValueError (fail-closed)."""

        class CustomObj:
            pass

        instance = CustomObj()
        with pytest.raises(ValueError):
            _canonicalize(instance)

    def test_object_with_value_attr_raises(self):
        """Objects with .value attribute (Unhashable-like) should raise ValueError."""

        class FakeUnhashable:
            def __init__(self):
                self.value = float('nan')

        instance = FakeUnhashable()
        with pytest.raises(ValueError):
            _canonicalize(instance)
class TestSerializeCacheKey:
    """Exercise _serialize_cache_key: stable, content-sensitive digests."""

    def test_same_content_same_hash(self):
        """Two separately built but equal keys serialize to the same digest."""
        first = frozenset([("node_1", frozenset([("input", "value")]))])
        second = frozenset([("node_1", frozenset([("input", "value")]))])
        assert _serialize_cache_key(first) == _serialize_cache_key(second)

    def test_different_content_different_hash(self):
        """Keys that differ in a value serialize to different digests."""
        with_a = frozenset([("node_1", "value_a")])
        with_b = frozenset([("node_1", "value_b")])
        assert _serialize_cache_key(with_a) != _serialize_cache_key(with_b)

    def test_returns_hex_string(self):
        """The digest is a str of 64 characters (SHA-256 hex digest length)."""
        digest = _serialize_cache_key(frozenset([("test", 123)]))
        assert isinstance(digest, str)
        assert len(digest) == 64

    def test_complex_nested_structure(self):
        """Serializing the same nested key twice gives the same digest."""
        # frozensets may only hold hashable members, so dict-like data is
        # modelled here as frozensets of (name, value) tuples.
        node_one_inputs = frozenset([
            ("input_a", ("tuple", "value")),
            ("input_b", frozenset([("nested", "dict")])),
        ])
        node_two_inputs = frozenset([
            ("param", 42),
        ])
        nested_key = frozenset([
            ("node_1", node_one_inputs),
            ("node_2", node_two_inputs),
        ])
        # Two passes over the same object must agree.
        first_pass = _serialize_cache_key(nested_key)
        second_pass = _serialize_cache_key(nested_key)
        assert first_pass == second_pass

    def test_dict_in_cache_key(self):
        """A plain dict is accepted as a key and hashes repeatably."""
        dict_key = {"node_1": {"input": "value"}, "node_2": 42}
        first_pass = _serialize_cache_key(dict_key)
        second_pass = _serialize_cache_key(dict_key)
        assert first_pass == second_pass
        assert isinstance(first_pass, str)
        assert len(first_pass) == 64

    def test_unknown_type_returns_none(self):
        """An unsupported key type yields None instead of a digest."""

        class CustomObj:
            pass

        outcome = _serialize_cache_key(CustomObj())
        assert outcome is None
class TestContainsSelfUnequal:
    """Exercise _contains_self_unequal on scalars, containers and objects."""

    def test_nan_float_detected(self):
        """A bare NaN is reported, since NaN does not equal itself."""
        not_a_number = float("nan")
        assert _contains_self_unequal(not_a_number) is True

    def test_regular_float_not_detected(self):
        """Ordinary finite floats are not reported."""
        for ordinary in (3.14, 0.0, -1.5):
            assert _contains_self_unequal(ordinary) is False

    def test_infinity_not_detected(self):
        """Positive and negative infinity are not reported."""
        for infinite in (float("inf"), float("-inf")):
            assert _contains_self_unequal(infinite) is False

    def test_nan_in_list(self):
        """A list is reported only when it holds a NaN."""
        tainted = [1, 2, float("nan"), 4]
        clean = [1, 2, 3, 4]
        assert _contains_self_unequal(tainted) is True
        assert _contains_self_unequal(clean) is False

    def test_nan_in_tuple(self):
        """A tuple is reported only when it holds a NaN."""
        tainted = (1, float("nan"))
        clean = (1, 2, 3)
        assert _contains_self_unequal(tainted) is True
        assert _contains_self_unequal(clean) is False

    def test_nan_in_frozenset(self):
        """A frozenset is reported only when it holds a NaN."""
        tainted = frozenset([1, float("nan")])
        clean = frozenset([1, 2, 3])
        assert _contains_self_unequal(tainted) is True
        assert _contains_self_unequal(clean) is False

    def test_nan_in_dict_value(self):
        """A dict is reported only when one of its values is a NaN."""
        tainted = {"key": float("nan")}
        clean = {"key": 42}
        assert _contains_self_unequal(tainted) is True
        assert _contains_self_unequal(clean) is False

    def test_nan_in_nested_structure(self):
        """A NaN buried under dict -> list -> dict -> tuple is still found."""
        innermost = (1, 2, float("nan"))
        nested = {"level1": [{"level2": innermost}]}
        assert _contains_self_unequal(nested) is True

    def test_non_numeric_types(self):
        """A str, None and a bool are not reported."""
        for plain in ("string", None, True):
            assert _contains_self_unequal(plain) is False

    def test_object_with_nan_value_attribute(self):
        """An object whose .value attribute is NaN is reported."""

        class NanWrapper:
            def __init__(self):
                self.value = float("nan")

        wrapper = NanWrapper()
        assert _contains_self_unequal(wrapper) is True

    def test_custom_self_unequal_object(self):
        """An object whose __eq__ always returns False is reported."""

        class NeverEqual:
            def __eq__(self, other):
                return False

        oddball = NeverEqual()
        assert _contains_self_unequal(oddball) is True
class TestEstimateValueSize:
    """Exercise _estimate_value_size on empty and tensor-bearing values."""

    def test_empty_outputs(self):
        """A CacheValue with no outputs is estimated at zero."""
        empty = CacheValue(outputs=[])
        assert _estimate_value_size(empty) == 0

    @pytest.mark.skipif(not _torch_available(), reason="PyTorch not available")
    def test_tensor_size_estimation(self):
        """A 1000-element float32 tensor is estimated at 4000."""
        import torch

        # 1000 float32 elements at 4 bytes each come to 4000 bytes.
        payload = torch.zeros(1000, dtype=torch.float32)
        wrapped = CacheValue(outputs=[[payload]])
        assert _estimate_value_size(wrapped) == 4000

    @pytest.mark.skipif(not _torch_available(), reason="PyTorch not available")
    def test_nested_tensor_in_dict(self):
        """A tensor stored as a dict value inside the outputs is counted."""
        import torch

        # 100 float32 elements at 4 bytes each come to 400 bytes.
        payload = torch.zeros(100, dtype=torch.float32)
        wrapped = CacheValue(outputs=[[{"samples": payload}]])
        assert _estimate_value_size(wrapped) == 400
class TestProviderRegistry:
    """Exercise registering, unregistering and clearing cache providers."""

    def setup_method(self):
        """Empty the provider registry before each test runs."""
        _clear_cache_providers()

    def teardown_method(self):
        """Empty the provider registry again once each test finishes."""
        _clear_cache_providers()

    def test_register_provider(self):
        """A registered provider is reported and returned by identity."""
        mock = MockCacheProvider()
        register_cache_provider(mock)
        assert _has_cache_providers() is True
        registered = _get_cache_providers()
        assert len(registered) == 1
        assert registered[0] is mock

    def test_unregister_provider(self):
        """Unregistering the only provider leaves the registry empty."""
        mock = MockCacheProvider()
        register_cache_provider(mock)
        unregister_cache_provider(mock)
        assert _has_cache_providers() is False

    def test_multiple_providers(self):
        """Two distinct providers are both kept."""
        mocks = (MockCacheProvider(), MockCacheProvider())
        for mock in mocks:
            register_cache_provider(mock)
        registered = _get_cache_providers()
        assert len(registered) == 2

    def test_duplicate_registration_ignored(self):
        """Registering the same instance twice keeps a single entry."""
        mock = MockCacheProvider()
        register_cache_provider(mock)
        # The second call with the same instance must not add another entry.
        register_cache_provider(mock)
        registered = _get_cache_providers()
        assert len(registered) == 1

    def test_clear_providers(self):
        """_clear_cache_providers drops every registered provider."""
        mocks = (MockCacheProvider(), MockCacheProvider())
        for mock in mocks:
            register_cache_provider(mock)
        _clear_cache_providers()
        assert _has_cache_providers() is False
        assert len(_get_cache_providers()) == 0
class TestCacheContext:
    """Exercise construction of CacheContext."""

    def test_context_creation(self):
        """Each keyword passed to CacheContext is readable as an attribute."""
        expected_hash = "a" * 64
        ctx = CacheContext(
            node_id="node-456",
            class_type="KSampler",
            cache_key_hash=expected_hash,
        )
        assert ctx.node_id == "node-456"
        assert ctx.class_type == "KSampler"
        assert ctx.cache_key_hash == expected_hash
class TestCacheValue:
    """Exercise construction of CacheValue."""

    def test_value_creation(self):
        """The outputs passed to CacheValue are exposed via .outputs."""
        sample_outputs = [[{"samples": "tensor_data"}]]
        wrapped = CacheValue(outputs=sample_outputs)
        assert wrapped.outputs == sample_outputs
class MockCacheProvider(CacheProvider):
    """Recording test double: logs every call and never returns a cached value."""

    def __init__(self):
        # Contexts passed to on_lookup, in call order.
        self.lookups = []
        # (context, value) pairs passed to on_store, in call order.
        self.stores = []

    async def on_lookup(self, context: CacheContext) -> Optional[CacheValue]:
        """Record the lookup context and report a miss by returning None."""
        self.lookups.append(context)
        return None

    async def on_store(self, context: CacheContext, value: CacheValue) -> None:
        """Record the context and value of a store request as a tuple."""
        record = (context, value)
        self.stores.append(record)

View File

@ -28,31 +28,31 @@ CACHE_SCENARIOS = [
}, },
# JavaScript/CSS scenarios # JavaScript/CSS scenarios
{ {
"name": "js_no_cache", "name": "js_no_store",
"path": "/script.js", "path": "/script.js",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
{ {
"name": "css_no_cache", "name": "css_no_store",
"path": "/styles.css", "path": "/styles.css",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
{ {
"name": "index_json_no_cache", "name": "index_json_no_store",
"path": "/api/index.json", "path": "/api/index.json",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
{ {
"name": "localized_index_json_no_cache", "name": "localized_index_json_no_store",
"path": "/templates/index.zh.json", "path": "/templates/index.zh.json",
"status": 200, "status": 200,
"expected_cache": "no-cache", "expected_cache": "no-store",
"should_have_header": True, "should_have_header": True,
}, },
# Non-matching files # Non-matching files