mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-15 20:39:48 +08:00
Merge branch 'master' into alexis/add_output_save_nodes
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
Some checks failed
Python Linting / Run Ruff (push) Has been cancelled
Python Linting / Run Pylint (push) Has been cancelled
Build package / Build Test (3.10) (push) Has been cancelled
Build package / Build Test (3.11) (push) Has been cancelled
Build package / Build Test (3.12) (push) Has been cancelled
Build package / Build Test (3.13) (push) Has been cancelled
Build package / Build Test (3.14) (push) Has been cancelled
This commit is contained in:
commit
a4ef68d043
@ -1,5 +1,4 @@
|
||||
As of the time of writing this you need this driver for best results:
|
||||
https://www.amd.com/en/resources/support-articles/release-notes/RN-AMDGPU-WINDOWS-PYTORCH-7-1-1.html
|
||||
As of the time of writing this you need a recent driver. Updating to the latest driver is recommended.
|
||||
|
||||
HOW TO RUN:
|
||||
|
||||
@ -7,9 +6,9 @@ If you have a AMD gpu:
|
||||
|
||||
run_amd_gpu.bat
|
||||
|
||||
If you have memory issues you can try disabling the smart memory management by running comfyui with:
|
||||
If you have memory issues you can try enabling the new dynamic memory management by running comfyui with:
|
||||
|
||||
run_amd_gpu_disable_smart_memory.bat
|
||||
run_amd_gpu_enable_dynamic_vram.bat
|
||||
|
||||
IF YOU GET A RED ERROR IN THE UI MAKE SURE YOU HAVE A MODEL/CHECKPOINT IN: ComfyUI\models\checkpoints
|
||||
|
||||
|
||||
2
.github/workflows/check-line-endings.yml
vendored
2
.github/workflows/check-line-endings.yml
vendored
@ -17,7 +17,7 @@ jobs:
|
||||
- name: Check for Windows line endings (CRLF)
|
||||
run: |
|
||||
# Get the list of changed files in the PR
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }})
|
||||
CHANGED_FILES=$(git diff --name-only ${{ github.event.pull_request.base.sha }}..${{ github.event.pull_request.head.sha }} -- ':!.ci')
|
||||
|
||||
# Flag to track if CRLF is found
|
||||
CRLF_FOUND=false
|
||||
|
||||
12
README.md
12
README.md
@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
@ -462,16 +462,6 @@ To use the most up-to-date frontend version:
|
||||
|
||||
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
# QA
|
||||
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
39
alembic_db/versions/0004_drop_tag_type.py
Normal file
39
alembic_db/versions/0004_drop_tag_type.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Drop the vestigial tags.tag_type column.
|
||||
|
||||
tag_type was always "user" in practice — no code path ever set it to anything
|
||||
else (no system/seeded classification was ever wired up) and nothing queried it.
|
||||
The column, its index (ix_tags_tag_type), and the corresponding API field were
|
||||
dead weight, so they are removed.
|
||||
|
||||
Revision ID: 0004_drop_tag_type
|
||||
Revises: 0003_add_metadata_job_id
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0004_drop_tag_type"
|
||||
down_revision = "0003_add_metadata_job_id"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("tags") as batch_op:
|
||||
batch_op.drop_index("ix_tags_tag_type")
|
||||
batch_op.drop_column("tag_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("tags") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"tag_type",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
)
|
||||
)
|
||||
batch_op.create_index("ix_tags_tag_type", ["tag_type"])
|
||||
@ -39,6 +39,7 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -174,7 +175,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
@ -211,24 +212,37 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order_candidate = (q.order or "desc").lower()
|
||||
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
|
||||
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
try:
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after=q.after,
|
||||
)
|
||||
except InvalidCursorError as e:
|
||||
return _build_error_response(400, "INVALID_CURSOR", str(e))
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
|
||||
# has_more semantics differ by mode:
|
||||
# - cursor mode: a non-empty next_cursor means there are more results.
|
||||
# - offset mode: derived from total - (offset + page size).
|
||||
if q.after is not None:
|
||||
has_more = result.next_cursor is not None
|
||||
else:
|
||||
has_more = (q.offset + len(summaries)) < result.total
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=result.total,
|
||||
has_more=(q.offset + len(summaries)) < result.total,
|
||||
has_more=has_more,
|
||||
next_cursor=result.next_cursor,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
@ -519,18 +533,14 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
@_require_assets_feature_enabled
|
||||
async def delete_asset_route(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content_param = request.query.get("delete_content")
|
||||
delete_content = (
|
||||
False
|
||||
if delete_content_param is None
|
||||
else delete_content_param.lower() not in {"0", "false", "no"}
|
||||
)
|
||||
|
||||
try:
|
||||
# Deleting an asset is a soft delete of the reference; the underlying
|
||||
# content is preserved (it may be shared with other references).
|
||||
deleted = delete_asset_reference(
|
||||
reference_id=reference_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
delete_content_if_orphan=False,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
@ -575,8 +585,8 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
|
||||
tags = [
|
||||
schemas_out.TagUsage(name=name, count=count, type=tag_type)
|
||||
for (name, tag_type, count) in rows
|
||||
schemas_out.TagUsage(name=name, count=count)
|
||||
for (name, count) in rows
|
||||
]
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
|
||||
@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel):
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
|
||||
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
|
||||
# Supplying `after` together with `sort=last_access_time` returns
|
||||
# 400 INVALID_CURSOR; that sort only supports offset/limit.
|
||||
after: str | None = None
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
|
||||
@ -41,12 +41,13 @@ class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
# Opaque cursor for the next page. Omitted when there are no more results.
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
|
||||
@ -227,7 +227,6 @@ class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="tag",
|
||||
@ -240,7 +239,5 @@ class Tag(Base):
|
||||
overlaps="asset_reference_links,tag_links,tags,asset_reference",
|
||||
)
|
||||
|
||||
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
@ -266,9 +266,18 @@ def list_references_page(
|
||||
metadata_filter: dict | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
after_cursor_value: object | None = None,
|
||||
after_cursor_id: str | None = None,
|
||||
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
|
||||
"""List references with pagination, filtering, and sorting.
|
||||
|
||||
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
|
||||
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
|
||||
strictly after the given ``(sort_col, id)`` position in the active sort
|
||||
direction. The cursor value must already be typed for the column
|
||||
(datetime for time sorts, int for size, str for name); the caller decodes
|
||||
the opaque cursor string and resolves to the typed value.
|
||||
|
||||
Returns (references, tag_map, total_count).
|
||||
"""
|
||||
base = (
|
||||
@ -297,9 +306,31 @@ def list_references_page(
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetReference.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
descending = order == "desc"
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
|
||||
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
|
||||
if after_cursor_value is not None and after_cursor_id is not None:
|
||||
if descending:
|
||||
keyset = sa.or_(
|
||||
sort_col < after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
|
||||
)
|
||||
else:
|
||||
keyset = sa.or_(
|
||||
sort_col > after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
|
||||
)
|
||||
base = base.where(keyset)
|
||||
|
||||
# Secondary ORDER BY id (matching the primary direction) gives the keyset
|
||||
# comparison a deterministic tiebreaker on duplicate sort_col values.
|
||||
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
|
||||
sort_exp = sort_col.desc() if descending else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp, id_exp).limit(limit)
|
||||
if after_cursor_id is None:
|
||||
base = base.offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
|
||||
@ -55,13 +55,11 @@ def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str]) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
rows = [{"name": n} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
@ -97,7 +95,7 @@ def set_reference_tags(
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
ensure_tags_exist(session, to_add)
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
@ -142,7 +140,7 @@ def add_tags_to_reference(
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
ensure_tags_exist(session, norm)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
@ -289,7 +287,6 @@ def list_tags_with_usage(
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
@ -331,7 +328,7 @@ def list_tags_with_usage(
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
rows_norm = [(name, int(count or 0)) for (name, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
|
||||
@ -33,6 +33,7 @@ from app.assets.services.file_utils import (
|
||||
verify_file_unchanged,
|
||||
)
|
||||
from app.assets.services.hashing import HashCheckpoint, compute_blake3_hash
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.metadata_extract import extract_file_metadata
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
@ -354,7 +355,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
ensure_tags_exist(sess, tag_pool)
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_refs
|
||||
@ -506,6 +507,10 @@ def enrich_asset(
|
||||
|
||||
if extract_metadata and metadata:
|
||||
system_metadata = metadata.to_user_metadata()
|
||||
if mime_type and mime_type.startswith("image/"):
|
||||
dims = extract_image_dimensions(file_path, mime_type=mime_type)
|
||||
if dims:
|
||||
system_metadata.update(dims)
|
||||
set_reference_system_metadata(session, reference_id, system_metadata)
|
||||
|
||||
if full_hash:
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import timezone
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
@ -149,6 +160,16 @@ def delete_asset_reference(
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
"""Delete an asset reference.
|
||||
|
||||
With ``delete_content_if_orphan=False`` (a soft delete), the reference is
|
||||
hidden and the underlying content is preserved. With ``True``, the content
|
||||
is also removed once it becomes orphaned.
|
||||
|
||||
Note: the public DELETE /api/assets/{id} endpoint always soft-deletes
|
||||
(passes ``False``); the orphan-reclamation path is intentionally
|
||||
internal-only, retained for a future GC/admin caller.
|
||||
"""
|
||||
with create_session() as session:
|
||||
if not delete_content_if_orphan:
|
||||
# Soft delete: mark the reference as deleted but keep everything
|
||||
@ -242,6 +263,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
# Sort fields that support cursor pagination. `last_access_time` is not
|
||||
# in this list — it falls back to offset/limit.
|
||||
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -252,7 +278,39 @@ def list_assets_page(
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
after: str | None = None,
|
||||
) -> ListAssetsResult:
|
||||
"""List assets with optional cursor pagination.
|
||||
|
||||
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
|
||||
must match ``sort`` and be in the cursor-supported allowlist; mismatches
|
||||
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
|
||||
"""
|
||||
cursor_value: object | None = None
|
||||
cursor_id: str | None = None
|
||||
# Mint next_cursor on every page where the sort is cursor-supported, not
|
||||
# only when the request itself arrived with a cursor. Otherwise a first
|
||||
# request (no `after`) returns next_cursor=None and the client can never
|
||||
# enter cursor mode.
|
||||
mint_cursor = sort in _CURSOR_SORT_FIELDS
|
||||
|
||||
if after is not None:
|
||||
if sort not in _CURSOR_SORT_FIELDS:
|
||||
raise InvalidCursorError(
|
||||
f"cursor pagination is not supported for sort={sort!r}"
|
||||
)
|
||||
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
|
||||
if payload.sort_field != sort:
|
||||
raise InvalidCursorError(
|
||||
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
|
||||
)
|
||||
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
|
||||
|
||||
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
|
||||
# remaining" from "more rows past this page" without a second query. Drop
|
||||
# the sentinel before returning.
|
||||
fetch_limit = limit + 1 if mint_cursor else limit
|
||||
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
@ -261,12 +319,22 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
limit=fetch_limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after_cursor_value=cursor_value,
|
||||
after_cursor_id=cursor_id,
|
||||
)
|
||||
|
||||
next_cursor: str | None = None
|
||||
if mint_cursor and len(refs) > limit:
|
||||
# There's at least one more row past this page — mint a cursor from
|
||||
# the last row of the page (i.e. index `limit - 1`, since we
|
||||
# over-fetched), and drop the sentinel.
|
||||
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
|
||||
refs = refs[:limit]
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
@ -277,7 +345,39 @@ def list_assets_page(
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
|
||||
|
||||
|
||||
def _resolve_cursor_value(payload: CursorPayload) -> object:
|
||||
"""Map a decoded cursor payload to a column-typed Python value."""
|
||||
if payload.sort_field in ("created_at", "updated_at"):
|
||||
# DB stores naive UTC; strip tzinfo so the comparison binds against a
|
||||
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
|
||||
return decode_cursor_time(payload).replace(tzinfo=None)
|
||||
if payload.sort_field == "size":
|
||||
return decode_cursor_int(payload)
|
||||
return payload.value # name, str-typed
|
||||
|
||||
|
||||
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
|
||||
"""Mint a cursor pointing at *ref* for the given sort dimension.
|
||||
|
||||
Returns None when the boundary row carries a NULL sort value (e.g. an asset
|
||||
record whose size_bytes hasn't been backfilled). Continuing pagination
|
||||
across a NULL boundary is undefined under keyset ordering — better to
|
||||
truncate cleanly here than to mint a cursor that mis-positions.
|
||||
"""
|
||||
if sort == "name":
|
||||
return encode_cursor("name", ref.name, ref.id, order=order)
|
||||
if sort == "size":
|
||||
if ref.asset is None or ref.asset.size_bytes is None:
|
||||
return None
|
||||
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
|
||||
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
|
||||
value = ref.created_at if sort == "created_at" else ref.updated_at
|
||||
if value is None:
|
||||
return None
|
||||
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
|
||||
213
app/assets/services/cursor.py
Normal file
213
app/assets/services/cursor.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""Opaque keyset-pagination cursor for /api/assets.
|
||||
|
||||
Payload JSON uses short keys to keep the encoded length small:
|
||||
|
||||
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
|
||||
|
||||
The `o` key binds the cursor to the sort direction it was minted under,
|
||||
so replaying a `desc` cursor against an `asc` request fails with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
`o` is mandatory on every payload — a cursor without it is rejected as
|
||||
malformed.
|
||||
|
||||
Encoding is base64url with no padding. Cursors are opaque tokens: the
|
||||
payload format is internal to this server, and clients must treat a
|
||||
cursor as a black box handed back via `next_cursor`. No byte-level
|
||||
compatibility with any other implementation is required.
|
||||
|
||||
Time values are serialized as Unix microseconds (UTC) — microsecond
|
||||
precision is sufficient to round-trip the timestamps stored by the
|
||||
database without rounding rows in the same millisecond bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
class InvalidCursorError(ValueError):
|
||||
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
|
||||
|
||||
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
|
||||
"""
|
||||
|
||||
|
||||
# Wire-format length caps. Cursors are user-controlled, so caps protect the
|
||||
# decode path from oversized allocations and downstream SQL predicates from
|
||||
# unbounded strings.
|
||||
#
|
||||
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
|
||||
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
|
||||
# server then refuses on the next request.
|
||||
#
|
||||
# MAX_ENCODED_CURSOR_LENGTH is the decode-path guard, sized comfortably above
|
||||
# the largest cursor the per-field caps can produce. Worst case is value + id
|
||||
# at their caps with every character JSON-escaping to the six-byte `\uXXXX`
|
||||
# form (control characters), which is ~5.2 KB once base64url-encoded. At 8192
|
||||
# the encoder can never mint a cursor that exceeds it, so a freshly minted
|
||||
# cursor always decodes on the next request and there is no user-visible
|
||||
# "cursor too long" failure.
|
||||
MAX_ENCODED_CURSOR_LENGTH = 8192
|
||||
MAX_CURSOR_VALUE_LENGTH = 512
|
||||
MAX_CURSOR_ID_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CursorPayload:
|
||||
sort_field: str
|
||||
value: str
|
||||
id: str
|
||||
order: str
|
||||
|
||||
|
||||
_VALID_ORDERS = ("asc", "desc")
|
||||
|
||||
|
||||
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
|
||||
"""Encode a cursor payload as a base64url (no-padding) string.
|
||||
|
||||
`order` binds the cursor to the sort direction it was minted under so a
|
||||
later request with a flipped `order` query parameter is rejected with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
"""
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
|
||||
# Symmetric input validation: the encoder must reject anything the
|
||||
# decoder rejects, or the same server will mint cursors it then 400s on
|
||||
# the next request.
|
||||
if not id:
|
||||
raise InvalidCursorError("id must be non-empty")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
payload = {"s": sort_field, "v": value, "id": id, "o": order}
|
||||
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
|
||||
# No mint-time length guard is needed: the per-field caps above bound the
|
||||
# encoded length well below MAX_ENCODED_CURSOR_LENGTH (see its definition),
|
||||
# so the encoder can never produce a cursor the decode path would reject.
|
||||
return base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
|
||||
"""Encode a time-typed cursor at Unix microsecond precision.
|
||||
|
||||
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
|
||||
datetimes are rejected so callers can't accidentally encode the local
|
||||
wall-clock value of a UTC-stored timestamp.
|
||||
"""
|
||||
if t.tzinfo is None:
|
||||
raise ValueError("encode_cursor_from_time requires an aware datetime")
|
||||
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
|
||||
return encode_cursor(sort_field, str(micros), id, order=order)
|
||||
|
||||
|
||||
def decode_cursor(
|
||||
cursor: str,
|
||||
allowed_sort_fields: Iterable[str],
|
||||
expected_order: str | None = None,
|
||||
) -> CursorPayload:
|
||||
"""Parse an opaque cursor.
|
||||
|
||||
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
|
||||
cursor carrying a field outside this set is rejected so a cursor minted
|
||||
for one column can't be replayed against another (e.g. a ``created_at``
|
||||
timestamp string compared against a ``name`` column).
|
||||
|
||||
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
|
||||
payload's ``o`` field. ``o`` is required on every payload; a cursor
|
||||
missing it is rejected as malformed.
|
||||
|
||||
Passing no allowed fields rejects every cursor.
|
||||
"""
|
||||
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("cursor exceeds maximum length")
|
||||
|
||||
try:
|
||||
# urlsafe_b64decode requires correct padding; we strip on encode, so
|
||||
# restore the trailing '=' pad here.
|
||||
padding = "=" * (-len(cursor) % 4)
|
||||
raw = base64.urlsafe_b64decode(cursor + padding)
|
||||
except (ValueError, base64.binascii.Error) as e:
|
||||
raise InvalidCursorError(f"encoding: {e}") from e
|
||||
|
||||
try:
|
||||
decoded = json.loads(raw)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise InvalidCursorError(f"payload: {e}") from e
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
raise InvalidCursorError("payload: expected object")
|
||||
|
||||
sort_field = decoded.get("s")
|
||||
value = decoded.get("v")
|
||||
id = decoded.get("id")
|
||||
order = decoded.get("o")
|
||||
|
||||
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
|
||||
raise InvalidCursorError("payload: missing or non-string s/v/id")
|
||||
|
||||
if id == "":
|
||||
raise InvalidCursorError("missing id")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
|
||||
if sort_field not in allowed_sort_fields:
|
||||
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
|
||||
|
||||
if not isinstance(order, str):
|
||||
raise InvalidCursorError("missing or non-string o")
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"unsupported order {order!r}")
|
||||
if expected_order is not None and order != expected_order:
|
||||
raise InvalidCursorError(
|
||||
f"cursor order {order!r} does not match request order {expected_order!r}"
|
||||
)
|
||||
|
||||
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
|
||||
|
||||
|
||||
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
|
||||
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
micros = int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
|
||||
try:
|
||||
return _unix_micros_to_datetime(micros)
|
||||
except (OverflowError, OSError, ValueError) as e:
|
||||
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
|
||||
# in fromtimestamp / datetime construction. Map to 400, not 500.
|
||||
raise InvalidCursorError(f"value is out of representable range: {e}") from e
|
||||
|
||||
|
||||
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
|
||||
"""Parse a cursor value as a base-10 integer."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
return int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
|
||||
|
||||
|
||||
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _datetime_to_unix_micros(t: datetime) -> int:
|
||||
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
|
||||
delta = t - _EPOCH
|
||||
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
|
||||
|
||||
|
||||
def _unix_micros_to_datetime(micros: int) -> datetime:
|
||||
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
|
||||
seconds, micro_remainder = divmod(micros, 1_000_000)
|
||||
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)
|
||||
63
app/assets/services/image_dimensions.py
Normal file
63
app/assets/services/image_dimensions.py
Normal file
@ -0,0 +1,63 @@
|
||||
"""Image dimension extraction for asset ingest.
|
||||
|
||||
Reads only the image header via Pillow to capture width/height cheaply,
|
||||
without a full pixel decode. Returns a metadata dict suitable for merging
|
||||
into ``AssetReference.system_metadata``.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def extract_image_dimensions(
|
||||
file_path: str, mime_type: str | None = None
|
||||
) -> dict[str, Any] | None:
|
||||
"""Extract image dimensions for the file at ``file_path``.
|
||||
|
||||
Args:
|
||||
file_path: Absolute path to a file on disk.
|
||||
mime_type: Optional MIME type hint. When provided and not prefixed
|
||||
with ``image/``, extraction is skipped without touching the file.
|
||||
|
||||
Returns:
|
||||
``{"kind": "image", "width": W, "height": H}`` when the file is a
|
||||
recognizable image with positive dimensions, otherwise ``None``.
|
||||
|
||||
The dict shape is intended to be merged into ``system_metadata`` so the
|
||||
asset response surfaces ``metadata.kind`` plus dimension fields for image
|
||||
assets. Forward-compatible: future media kinds (e.g. ``"video"`` with
|
||||
duration/fps) can extend this shape without schema changes.
|
||||
"""
|
||||
if mime_type is not None and not mime_type.startswith("image/"):
|
||||
return None
|
||||
|
||||
try:
|
||||
from PIL import Image, UnidentifiedImageError
|
||||
except ImportError:
|
||||
logger.debug(
|
||||
"Pillow not available; skipping image dimension extraction for %s",
|
||||
file_path,
|
||||
)
|
||||
return None
|
||||
|
||||
try:
|
||||
with Image.open(file_path) as img:
|
||||
width, height = img.size
|
||||
except (OSError, UnidentifiedImageError, ValueError) as exc:
|
||||
logger.debug(
|
||||
"Failed to read image dimensions from %s: %s", file_path, exc
|
||||
)
|
||||
return None
|
||||
|
||||
if (
|
||||
not isinstance(width, int)
|
||||
or not isinstance(height, int)
|
||||
or width <= 0
|
||||
or height <= 0
|
||||
):
|
||||
return None
|
||||
|
||||
return {"kind": "image", "width": width, "height": height}
|
||||
@ -17,9 +17,11 @@ from app.assets.database.queries import (
|
||||
get_reference_by_file_path,
|
||||
get_reference_tags,
|
||||
get_or_create_reference,
|
||||
list_references_by_asset_id,
|
||||
reference_exists,
|
||||
remove_missing_tag_for_asset_id,
|
||||
set_reference_metadata,
|
||||
set_reference_system_metadata,
|
||||
set_reference_tags,
|
||||
update_asset_hash_and_mime,
|
||||
upsert_asset,
|
||||
@ -29,6 +31,7 @@ from app.assets.database.queries import (
|
||||
from app.assets.helpers import get_utc_now, normalize_tags
|
||||
from app.assets.services.bulk_ingest import batch_insert_seed_assets
|
||||
from app.assets.services.file_utils import get_size_and_mtime_ns
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
from app.assets.services.path_utils import (
|
||||
compute_relative_filename,
|
||||
get_name_and_tags_from_asset_path,
|
||||
@ -118,6 +121,14 @@ def _ingest_file_from_path(
|
||||
user_metadata=user_metadata,
|
||||
)
|
||||
|
||||
_maybe_store_image_dimensions(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
file_path=locator,
|
||||
mime_type=mime_type,
|
||||
current_system_metadata=ref.system_metadata,
|
||||
)
|
||||
|
||||
try:
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
except Exception:
|
||||
@ -288,6 +299,13 @@ def _register_existing_asset(
|
||||
user_metadata=new_meta,
|
||||
)
|
||||
|
||||
_backfill_image_dimensions_from_siblings(
|
||||
session,
|
||||
asset_id=asset.id,
|
||||
new_reference_id=ref.id,
|
||||
current_system_metadata=ref.system_metadata,
|
||||
)
|
||||
|
||||
if tags is not None:
|
||||
set_reference_tags(
|
||||
session,
|
||||
@ -334,6 +352,87 @@ def _update_metadata_with_filename(
|
||||
)
|
||||
|
||||
|
||||
_IMAGE_DIMENSION_KEYS = ("kind", "width", "height")
|
||||
|
||||
|
||||
def _maybe_store_image_dimensions(
|
||||
session: Session,
|
||||
reference_id: str,
|
||||
file_path: str,
|
||||
mime_type: str | None,
|
||||
current_system_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Populate ``kind``/``width``/``height`` on system_metadata for image refs.
|
||||
|
||||
Non-image MIME types are a no-op. Pre-existing keys (e.g. enricher-written
|
||||
safetensors metadata, download provenance) are preserved by merge.
|
||||
"""
|
||||
if not mime_type or not mime_type.startswith("image/"):
|
||||
return
|
||||
|
||||
dims = extract_image_dimensions(file_path, mime_type=mime_type)
|
||||
if not dims:
|
||||
return
|
||||
|
||||
current = current_system_metadata or {}
|
||||
merged = dict(current)
|
||||
merged.update(dims)
|
||||
if merged != current:
|
||||
set_reference_system_metadata(
|
||||
session,
|
||||
reference_id=reference_id,
|
||||
system_metadata=merged,
|
||||
)
|
||||
|
||||
|
||||
def _backfill_image_dimensions_from_siblings(
|
||||
session: Session,
|
||||
asset_id: str,
|
||||
new_reference_id: str,
|
||||
current_system_metadata: dict | None,
|
||||
) -> None:
|
||||
"""Copy image dimension keys from any sibling reference of the same asset.
|
||||
|
||||
The from-hash path doesn't read the file bytes, so dimensions can't be
|
||||
extracted there directly. When another reference of the same asset already
|
||||
carries image dimensions, copy them onto the new reference so consumers
|
||||
see consistent metadata regardless of how the asset was registered.
|
||||
|
||||
Best-effort: missing siblings, non-image siblings, or absent dimension
|
||||
keys leave the target reference unchanged.
|
||||
"""
|
||||
current = current_system_metadata or {}
|
||||
if current.get("kind") == "image" and "width" in current and "height" in current:
|
||||
return
|
||||
|
||||
for sibling in list_references_by_asset_id(session, asset_id):
|
||||
if sibling.id == new_reference_id:
|
||||
continue
|
||||
meta = sibling.system_metadata or {}
|
||||
if meta.get("kind") != "image":
|
||||
continue
|
||||
width = meta.get("width")
|
||||
height = meta.get("height")
|
||||
if (
|
||||
type(width) is not int
|
||||
or type(height) is not int
|
||||
or width <= 0
|
||||
or height <= 0
|
||||
):
|
||||
continue
|
||||
merged = dict(current)
|
||||
merged["kind"] = "image"
|
||||
merged["width"] = width
|
||||
merged["height"] = height
|
||||
if merged != current:
|
||||
set_reference_system_metadata(
|
||||
session,
|
||||
reference_id=new_reference_id,
|
||||
system_metadata=merged,
|
||||
)
|
||||
return
|
||||
|
||||
|
||||
def _sanitize_filename(name: str | None, fallback: str) -> str:
|
||||
n = os.path.basename((name or "").strip() or fallback)
|
||||
return n if n else fallback
|
||||
|
||||
@ -56,7 +56,6 @@ class IngestResult:
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
@ -71,6 +70,7 @@ class AssetSummaryData:
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@ -75,7 +75,7 @@ def list_tags(
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
return [TagUsage(name, count) for name, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
|
||||
@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
@ -166,6 +167,8 @@ class PerformanceFeature(enum.Enum):
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
@ -247,6 +250,9 @@ else:
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.high_ram:
|
||||
args.cache_classic = True
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -256,6 +262,10 @@ if args.disable_auto_launch:
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'.
|
||||
if args.enable_manager_legacy_ui:
|
||||
args.enable_manager = True
|
||||
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.ldm.depth_anything_3.reference_view_selector import (
|
||||
select_reference_view, reorder_by_reference, restore_original_order,
|
||||
THRESH_FOR_REF_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
@ -14,13 +20,41 @@ class Dino2AttentionOutput(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.head_dim = embed_dim // heads
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
|
||||
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
|
||||
if self.q_norm is None and rope is None:
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
|
||||
attn = self.attention
|
||||
B, N, C = x.shape
|
||||
h = self.heads
|
||||
d = self.head_dim
|
||||
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
|
||||
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
|
||||
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
|
||||
if self.q_norm is not None:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if rope is not None and pos is not None:
|
||||
q = rope(q, pos)
|
||||
k = rope(k, pos)
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
return self.output(out)
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
@ -64,9 +98,11 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=qk_norm)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
@ -76,19 +112,90 @@ class Dino2Block(torch.nn.Module):
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
|
||||
pos=pos, rope=rope))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D Rotary position embedding (DA3 extension)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PositionGetter:
|
||||
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width, device)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(torch.nn.Module):
|
||||
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
|
||||
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
|
||||
qknorm_start: int = -1):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([
|
||||
Dino2Block(
|
||||
dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
@ -122,16 +229,27 @@ class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
def __init__(self, dim, dtype, device, operations,
|
||||
patch_size: int = 14, image_size: int = 518,
|
||||
use_mask_token: bool = True,
|
||||
num_camera_tokens: int = 0):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
if use_mask_token:
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.mask_token = None
|
||||
if num_camera_tokens > 0:
|
||||
# DA3 stores (ref_token, src_token) pairs that get injected at the
|
||||
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
|
||||
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.camera_token = None
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
|
||||
@ -140,12 +258,22 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
patch_pos = pos_embed[:, 1:]
|
||||
N = patch_pos.shape[1]
|
||||
M = int(N ** 0.5)
|
||||
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
|
||||
h0 = h_pixels // self.patch_size
|
||||
w0 = w_pixels // self.patch_size
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# scale_factor is (height_scale, width_scale) -- height MUST come first;
|
||||
# swapping these only happens to work for square inputs and breaks
|
||||
# non-square paths like DA3-Small / DA3-Base multi-view.
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
|
||||
|
||||
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
|
||||
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
|
||||
assert (h0, w0) == patch_pos.shape[-2:], (
|
||||
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
|
||||
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
|
||||
f"check scale_factor axis order and +0.1 rounding workaround"
|
||||
)
|
||||
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
|
||||
|
||||
@ -168,12 +296,51 @@ class Dinov2Model(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
patch_size = config_dict.get("patch_size", 14)
|
||||
image_size = config_dict.get("image_size", 518)
|
||||
use_mask_token = config_dict.get("use_mask_token", True)
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
# DA3 extensions (all default to disabled).
|
||||
self.alt_start = config_dict.get("alt_start", -1)
|
||||
self.qknorm_start = config_dict.get("qknorm_start", -1)
|
||||
self.rope_start = config_dict.get("rope_start", -1)
|
||||
self.cat_token = config_dict.get("cat_token", False)
|
||||
rope_freq = config_dict.get("rope_freq", 100.0)
|
||||
|
||||
self.embed_dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
|
||||
if self.rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self._position_getter = _PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self._position_getter = None
|
||||
|
||||
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
|
||||
num_cam_tokens = 2 if self.alt_start != -1 else 0
|
||||
|
||||
self.embeddings = Dino2Embeddings(
|
||||
dim, dtype, device, operations,
|
||||
patch_size=patch_size, image_size=image_size,
|
||||
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
|
||||
)
|
||||
self.encoder = Dino2Encoder(
|
||||
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qknorm_start=self.qknorm_start,
|
||||
)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
if self.alt_start != -1:
|
||||
raise RuntimeError(
|
||||
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
|
||||
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
|
||||
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
|
||||
)
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
@ -181,6 +348,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
"""Single-view multi-layer feature extraction."""
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
@ -197,3 +365,132 @@ class Dinov2Model(torch.nn.Module):
|
||||
if i >= max_idx:
|
||||
break
|
||||
return [cache[i] for i in resolved]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Depth Anything 3 forward
|
||||
# ------------------------------------------------------------------
|
||||
def _prepare_rope_positions(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
pos = self._position_getter(B * S, ph, pw, device=device)
|
||||
# Shift so the cls/cam token at position 0 is reserved for "no diff".
|
||||
pos = pos + 1
|
||||
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
# Per-view local: real grid positions for patches, 0 for cls token.
|
||||
pos_local = torch.cat([cls_pos, pos], dim=1)
|
||||
# Global (across views): same grid positions; cls token still at 0,
|
||||
# but patches share the same positions in every view.
|
||||
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
|
||||
return pos_local, pos_global
|
||||
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
# x: (B, S, N, C). Replace token at index 0 with the camera token.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
return x
|
||||
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None):
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3."""
|
||||
if pixel_values.ndim == 4:
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
|
||||
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
|
||||
B, S, _, H, W = pixel_values.shape
|
||||
|
||||
# Patch + cls + (interpolated) pos embed for each view.
|
||||
x = pixel_values.reshape(B * S, 3, H, W)
|
||||
x = self.embeddings(x) # (B*S, 1+N, C)
|
||||
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
|
||||
|
||||
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
|
||||
# optimized_attention is only used by blocks without QK-norm/RoPE
|
||||
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
out_set = set(out_layers)
|
||||
export_set = set(export_feat_layers) if export_feat_layers else set()
|
||||
outputs: list[torch.Tensor] = []
|
||||
aux_outputs: list[torch.Tensor] = []
|
||||
local_x = x
|
||||
b_idx = None
|
||||
|
||||
|
||||
for i, blk in enumerate(self.encoder.layer):
|
||||
apply_rope = self.rope is not None and i >= self.rope_start
|
||||
block_rope = self.rope if apply_rope else None
|
||||
l_pos = pos_local if apply_rope else None
|
||||
g_pos = pos_global if apply_rope else None
|
||||
|
||||
# Reference-view selection threshold: matches the upstream constant
|
||||
# THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied
|
||||
# cam_token is provided (camera info already pins the geometry).
|
||||
if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
|
||||
b_idx = select_reference_view(x, strategy=ref_view_strategy)
|
||||
x = reorder_by_reference(x, b_idx)
|
||||
local_x = reorder_by_reference(local_x, b_idx)
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
x = self._inject_camera_token(x, B, S, cam_token)
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
# Global attention across views: flatten S into the seq dim.
|
||||
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
|
||||
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
else:
|
||||
# Per-view local attention.
|
||||
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
|
||||
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
local_x = x
|
||||
|
||||
if i in out_set:
|
||||
if self.cat_token:
|
||||
out_x = torch.cat([local_x, x], dim=-1)
|
||||
else:
|
||||
out_x = x
|
||||
# Restore original view order on the way out so heads see views
|
||||
# in the user's expected order.
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
out_x = restore_original_order(out_x, b_idx)
|
||||
outputs.append(out_x)
|
||||
|
||||
if i in export_set:
|
||||
aux = x
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
aux = restore_original_order(aux, b_idx)
|
||||
aux_outputs.append(aux)
|
||||
|
||||
# Apply final norm. When cat_token is set, only the right half
|
||||
# ("global" features) is normalised; the left half is left as-is to
|
||||
# match the upstream DA3 head signature.
|
||||
normed: list[torch.Tensor] = []
|
||||
cls_tokens: list[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
cls_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.layernorm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., :self.embed_dim]
|
||||
right = self.layernorm(out_x[..., self.embed_dim:])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
|
||||
# Drop cls/cam token from the patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
|
||||
# Final layernorm + drop cls token from auxiliary features too.
|
||||
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
|
||||
for o in aux_outputs]
|
||||
return list(zip(normed, cls_tokens)), aux_normed
|
||||
|
||||
25
comfy/ldm/colormap.py
Normal file
25
comfy/ldm/colormap.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""Colormap utilities for depth and geometry visualisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the Turbo colormap.
|
||||
|
||||
Args:
|
||||
x: Float tensor with values in [0, 1].
|
||||
|
||||
Returns:
|
||||
RGB tensor of the same shape as ``x`` with a trailing size-3 dimension.
|
||||
"""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
177
comfy/ldm/depth_anything_3/camera.py
Normal file
177
comfy/ldm/depth_anything_3/camera.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Camera-token encoder and decoder for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from .transform import affine_inverse, extri_intri_to_pose_encoding
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Building blocks (mirror depth_anything_3.model.utils.{attention,block})
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Mlp(nn.Module):
|
||||
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
|
||||
|
||||
class _LayerScale(nn.Module):
|
||||
"""Per-channel learnable scaling. Matches upstream LayerScale."""
|
||||
|
||||
def __init__(self, dim, *, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gamma.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
class _Attention(nn.Module):
|
||||
""" Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention;
|
||||
Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias})."""
|
||||
|
||||
def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
q, k, v = qkv.unbind(2) # each (B, N, C)
|
||||
attn_fn = optimized_attention_for_device(x.device, small_input=True)
|
||||
out = attn_fn(q, k, v, heads=self.num_heads)
|
||||
return self.proj(out)
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. Layout follows upstream utils.block.Block."""
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.ls1(self.attn(self.norm1(x)))
|
||||
x = x + self.ls2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class CameraEnc(nn.Module):
|
||||
"""Encode per-view (extrinsics, intrinsics) into a camera token.
|
||||
|
||||
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
|
||||
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
|
||||
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
|
||||
of the DINOv2 backbone in place of the cls token.
|
||||
|
||||
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_out: int = 1024,
|
||||
dim_in: int = 9,
|
||||
trunk_depth: int = 4,
|
||||
target_dim: int = 9,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: int = 4,
|
||||
init_values: float = 0.01,
|
||||
*,
|
||||
device=None, dtype=None, operations=None,
|
||||
**_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.target_dim = target_dim
|
||||
self.trunk_depth = trunk_depth
|
||||
self.trunk = nn.Sequential(*[
|
||||
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(trunk_depth)
|
||||
])
|
||||
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.pose_branch = _Mlp(
|
||||
in_features=dim_in,
|
||||
hidden_features=dim_out // 2,
|
||||
out_features=dim_out,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
||||
image_size_hw) -> torch.Tensor:
|
||||
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
|
||||
c2ws = affine_inverse(extrinsics)
|
||||
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
|
||||
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
|
||||
tokens = self.token_norm(tokens)
|
||||
tokens = self.trunk(tokens)
|
||||
tokens = self.trunk_norm(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
class CameraDec(nn.Module):
|
||||
"""Decode the final cam token into a 9-D pose encoding.
|
||||
|
||||
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
|
||||
always predicted by the network; the quaternion and FoV can either be
|
||||
predicted or supplied via ``camera_encoding`` (used at training time
|
||||
when GT cameras are available -- not exercised at inference here).
|
||||
|
||||
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int = 1536,
|
||||
*, device=None, dtype=None, operations=None, **_kwargs):
|
||||
super().__init__()
|
||||
d = dim_in
|
||||
self.backbone = nn.Sequential(
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
|
||||
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
|
||||
self.fc_fov = nn.Sequential(
|
||||
operations.Linear(d, 2, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, feat: torch.Tensor,
|
||||
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
|
||||
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
|
||||
B, N = feat.shape[:2]
|
||||
feat = feat.reshape(B * N, -1)
|
||||
feat = self.backbone(feat)
|
||||
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
|
||||
if camera_encoding is None:
|
||||
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
|
||||
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
|
||||
else:
|
||||
out_qvec = camera_encoding[..., 3:7]
|
||||
out_fov = camera_encoding[..., -2:]
|
||||
return torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
||||
489
comfy/ldm/depth_anything_3/dpt.py
Normal file
489
comfy/ldm/depth_anything_3/dpt.py
Normal file
@ -0,0 +1,489 @@
|
||||
"""DPT / DualDPT heads for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Permute(nn.Module):
|
||||
def __init__(self, dims: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(*self.dims)
|
||||
|
||||
|
||||
def _custom_interpolate(
|
||||
x: torch.Tensor,
|
||||
size: Optional[Tuple[int, int]] = None,
|
||||
scale_factor: Optional[float] = None,
|
||||
mode: str = "bilinear",
|
||||
align_corners: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if size is None:
|
||||
assert scale_factor is not None
|
||||
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
||||
INT_MAX = 1610612736
|
||||
total = size[0] * size[1] * x.shape[0] * x.shape[1]
|
||||
if total > INT_MAX:
|
||||
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
|
||||
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
|
||||
return torch.cat(outs, dim=0).contiguous()
|
||||
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
||||
|
||||
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
|
||||
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
|
||||
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
|
||||
span_x = aspect_ratio / diag_factor
|
||||
span_y = 1.0 / diag_factor
|
||||
left_x = -span_x * (width - 1) / width
|
||||
right_x = span_x * (width - 1) / width
|
||||
top_y = -span_y * (height - 1) / height
|
||||
bottom_y = span_y * (height - 1) / height
|
||||
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
||||
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
||||
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
||||
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
|
||||
|
||||
|
||||
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
|
||||
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
|
||||
pos = pos.reshape(-1)
|
||||
out = torch.einsum("m,d->md", pos, omega)
|
||||
return torch.cat([out.sin(), out.cos()], dim=1).float()
|
||||
|
||||
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor:
|
||||
H, W, _ = pos_grid.shape
|
||||
pos_flat = pos_grid.reshape(-1, 2)
|
||||
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
|
||||
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
|
||||
emb = torch.cat([emb_x, emb_y], dim=-1)
|
||||
return emb.view(H, W, embed_dim)
|
||||
|
||||
|
||||
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
||||
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
|
||||
pw, ph = x.shape[-1], x.shape[-2]
|
||||
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
||||
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
|
||||
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
|
||||
return x + pe
|
||||
|
||||
|
||||
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
|
||||
act = (activation or "linear").lower()
|
||||
if act == "exp":
|
||||
return torch.exp(x)
|
||||
if act == "expp1":
|
||||
return torch.exp(x) + 1
|
||||
if act == "expm1":
|
||||
return torch.expm1(x)
|
||||
if act == "relu":
|
||||
return torch.relu(x)
|
||||
if act == "sigmoid":
|
||||
return torch.sigmoid(x)
|
||||
if act == "softplus":
|
||||
return F.softplus(x)
|
||||
if act == "tanh":
|
||||
return torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fusion building blocks
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
def __init__(self, features: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
self.has_residual = has_residual
|
||||
if has_residual:
|
||||
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.resConfUnit1 = None
|
||||
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
|
||||
y = xs[0]
|
||||
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
||||
y = y + self.resConfUnit1(xs[1])
|
||||
y = self.resConfUnit2(y)
|
||||
if size is None:
|
||||
up_kwargs = {"scale_factor": 2.0}
|
||||
else:
|
||||
up_kwargs = {"size": size}
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
|
||||
y = self.out_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class _Scratch(nn.Module):
|
||||
"""Container that mirrors upstream ``scratch`` attribute layout."""
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
|
||||
scratch = _Scratch()
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual, align_corners=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DPT(nn.Module):
|
||||
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 1,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = False,
|
||||
down_ratio: int = 1,
|
||||
head_name: str = "depth",
|
||||
use_sky_head: bool = True,
|
||||
sky_name: str = "sky",
|
||||
sky_activation: str = "relu",
|
||||
norm_type: str = "idt",
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.head_main = head_name
|
||||
self.sky_name = sky_name
|
||||
self.out_dim = output_dim
|
||||
self.has_conf = output_dim > 1
|
||||
self.use_sky_head = use_sky_head
|
||||
self.sky_activation = sky_activation
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
|
||||
if norm_type == "layer":
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
if self.use_sky_head:
|
||||
self.scratch.sky_output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
||||
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
||||
out = self.scratch.refinenet1(out, l1_rn)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
fused = self.scratch.output_conv1(out)
|
||||
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
fused = _add_pos_embed(fused, W, H)
|
||||
feat = fused
|
||||
|
||||
main_logits = self.scratch.output_conv2(feat)
|
||||
outs = {}
|
||||
if self.has_conf:
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
|
||||
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
|
||||
else:
|
||||
pred = _apply_activation(main_logits, self.activation)
|
||||
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
|
||||
|
||||
if self.use_sky_head:
|
||||
sky_logits = self.scratch.sky_output_conv2(feat)
|
||||
if self.sky_activation.lower() == "sigmoid":
|
||||
sky = torch.sigmoid(sky_logits)
|
||||
elif self.sky_activation.lower() == "relu":
|
||||
sky = F.relu(sky_logits)
|
||||
else:
|
||||
sky = sky_logits
|
||||
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DualDPT(nn.Module):
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 2,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = True,
|
||||
down_ratio: int = 1,
|
||||
aux_pyramid_levels: int = 4,
|
||||
aux_out1_conv_num: int = 5,
|
||||
head_names: Tuple[str, str] = ("depth", "ray"),
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.aux_levels = aux_pyramid_levels
|
||||
self.aux_out1_conv_num = aux_out1_conv_num
|
||||
self.head_main, self.head_aux = head_names
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
|
||||
# DepthAnything3Net flips this on when running multi-view + ray-pose.
|
||||
self.enable_aux: bool = False
|
||||
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
# Main fusion chain
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
# Auxiliary fusion chain (separate copies)
|
||||
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
# Main head neck + final projection
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
# Aux pre-head per level (multi-level pyramid)
|
||||
self.scratch.output_conv1_aux = nn.ModuleList([
|
||||
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
# Aux final projection per level (includes LayerNorm permute path).
|
||||
ln_seq = [Permute((0, 2, 3, 1)),
|
||||
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
|
||||
Permute((0, 3, 1, 2))]
|
||||
self.scratch.output_conv2_aux = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
*ln_seq,
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
|
||||
# aux_out1_conv_num=5 in all Apache-2.0 variants.
|
||||
return nn.Sequential(
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
|
||||
# before interpolation -- replicate that order here).
|
||||
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
|
||||
aux_pyr = [a4]
|
||||
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
|
||||
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
|
||||
m = self.scratch.refinenet1(m, l1_rn)
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
|
||||
m = self.scratch.output_conv1(m)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
m = _add_pos_embed(m, W, H)
|
||||
main_logits = self.scratch.output_conv2(m)
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
|
||||
outs = {
|
||||
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
|
||||
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
|
||||
}
|
||||
|
||||
if self.enable_aux:
|
||||
# Auxiliary "ray" head (multi-level inside) -- only the last level
|
||||
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
|
||||
# each aux pyramid level goes through ``output_conv1_aux[i]``
|
||||
# (5-layer conv stack that ends at ``features // 2`` channels),
|
||||
# then the last level optionally gets a pos-embed and finally
|
||||
# ``output_conv2_aux[-1]``.
|
||||
aux_processed = [
|
||||
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
|
||||
]
|
||||
last_aux = aux_processed[-1]
|
||||
if self.pos_embed:
|
||||
last_aux = _add_pos_embed(last_aux, W, H)
|
||||
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
|
||||
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
|
||||
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
|
||||
aux_pred = fmap_last[..., :-1]
|
||||
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
|
||||
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
|
||||
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
|
||||
|
||||
return outs
|
||||
236
comfy/ldm/depth_anything_3/model.py
Normal file
236
comfy/ldm/depth_anything_3/model.py
Normal file
@ -0,0 +1,236 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .camera import CameraDec, CameraEnc
|
||||
from .dpt import DPT, DualDPT
|
||||
from .ray_pose import get_extrinsic_from_camray
|
||||
from .transform import affine_inverse, pose_encoding_to_extri_intri
|
||||
|
||||
|
||||
_HEAD_REGISTRY = {
|
||||
"dpt": DPT,
|
||||
"dualdpt": DualDPT,
|
||||
}
|
||||
|
||||
|
||||
# Backbone presets (mirror the upstream DINOv2 ViT variants).
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
|
||||
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
|
||||
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
|
||||
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
|
||||
}
|
||||
|
||||
|
||||
def _build_backbone_config(
|
||||
backbone_name: str,
|
||||
*,
|
||||
alt_start: int,
|
||||
qknorm_start: int,
|
||||
rope_start: int,
|
||||
cat_token: bool,
|
||||
) -> dict:
|
||||
if backbone_name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
|
||||
cfg = dict(_BACKBONE_PRESETS[backbone_name])
|
||||
cfg.update(dict(
|
||||
layer_norm_eps=1e-6,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
# No mask_token in DA3 weights; omit param to avoid load warnings.
|
||||
use_mask_token=False,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
rope_freq=100.0,
|
||||
))
|
||||
return cfg
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# --- Backbone ---
|
||||
backbone_name: str = "vitl",
|
||||
out_layers: Sequence[int] = (4, 11, 17, 23),
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = False,
|
||||
# --- Head ---
|
||||
head_type: str = "dpt", # dpt or dualdpt
|
||||
head_dim_in: int = 1024,
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_features: int = 256,
|
||||
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
# --- Camera (multi-view) ---
|
||||
has_cam_enc: bool = False,
|
||||
has_cam_dec: bool = False,
|
||||
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
|
||||
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
|
||||
# ComfyUI plumbing
|
||||
device=None, dtype=None, operations=None,
|
||||
**_ignored,
|
||||
):
|
||||
super().__init__()
|
||||
head_cls = _HEAD_REGISTRY[head_type.lower()]
|
||||
self.head_type = head_type.lower()
|
||||
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
|
||||
self.has_conf = head_output_dim > 1
|
||||
self.out_layers = list(out_layers)
|
||||
|
||||
backbone_cfg = _build_backbone_config(
|
||||
backbone_name,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
)
|
||||
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
|
||||
|
||||
head_kwargs = dict(
|
||||
dim_in=head_dim_in,
|
||||
patch_size=self.PATCH_SIZE,
|
||||
output_dim=head_output_dim,
|
||||
features=head_features,
|
||||
out_channels=tuple(head_out_channels),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
if self.head_type == "dpt":
|
||||
head_kwargs.update(
|
||||
use_sky_head=head_use_sky_head,
|
||||
pos_embed=(False if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
else: # dualdpt
|
||||
head_kwargs.update(
|
||||
pos_embed=(True if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
self.head = head_cls(**head_kwargs)
|
||||
|
||||
# Built only if checkpoint has weights; cam_enc output dim == embed_dim.
|
||||
embed_dim = backbone_cfg["hidden_size"]
|
||||
if has_cam_enc:
|
||||
self.cam_enc = CameraEnc(
|
||||
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
|
||||
num_heads=max(1, embed_dim // 64),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_enc = None
|
||||
if has_cam_dec:
|
||||
default_dim = embed_dim * (2 if cat_token else 1)
|
||||
self.cam_dec = CameraDec(
|
||||
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_dec = None
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
extrinsics: Optional[torch.Tensor] = None,
|
||||
intrinsics: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
use_ray_pose: bool = False,
|
||||
ref_view_strategy: str = "saddle_balanced",
|
||||
export_feat_layers: Optional[Sequence[int]] = None,
|
||||
**_unused,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Run depth and optionally pose prediction."""
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(1) # (B, 1, 3, H, W)
|
||||
assert image.ndim == 5 and image.shape[2] == 3, \
|
||||
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
|
||||
|
||||
B, S, _, H, W = image.shape
|
||||
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
|
||||
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
|
||||
|
||||
# Camera-token preparation (multi-view path).
|
||||
cam_token = None
|
||||
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
|
||||
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
|
||||
|
||||
# Toggle aux ray output on/off depending on what the caller asked for.
|
||||
if isinstance(self.head, DualDPT):
|
||||
self.head.enable_aux = bool(use_ray_pose)
|
||||
|
||||
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
|
||||
image, self.out_layers, cam_token=cam_token,
|
||||
ref_view_strategy=ref_view_strategy,
|
||||
export_feat_layers=export_feat_layers,
|
||||
)
|
||||
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
|
||||
|
||||
# Pose prediction.
|
||||
out: Dict[str, torch.Tensor] = {}
|
||||
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
|
||||
ray = head_out["ray"]
|
||||
ray_conf = head_out["ray_conf"]
|
||||
extr_c2w, focal, pp = get_extrinsic_from_camray(
|
||||
ray, ray_conf, ray.shape[-3], ray.shape[-2],
|
||||
)
|
||||
# Match the upstream output: w2c, drop the homogeneous row.
|
||||
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
|
||||
# Build pixel-space intrinsics from the normalised focal/pp output.
|
||||
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
|
||||
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
|
||||
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
|
||||
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
|
||||
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
|
||||
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
|
||||
out["extrinsics"] = extr_w2c
|
||||
out["intrinsics"] = intr
|
||||
elif self.cam_dec is not None and S > 1:
|
||||
# Decode the cam-token of the final out_layer into a pose encoding.
|
||||
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
|
||||
pose_enc = self.cam_dec(cam_feat)
|
||||
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
|
||||
# Match the upstream output convention: w2c (world->camera), 3x4.
|
||||
c2w_4x4 = torch.cat([
|
||||
c2w_3x4,
|
||||
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
|
||||
.view(1, 1, 1, 4).expand(B, S, 1, 4),
|
||||
], dim=-2)
|
||||
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
|
||||
out["intrinsics"] = intr
|
||||
|
||||
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
|
||||
# per-image consumer keeps its (B*S, H, W) interface.
|
||||
for k, v in head_out.items():
|
||||
if k in ("ray", "ray_conf"):
|
||||
# Keep multi-view shape for downstream pose work.
|
||||
out[k] = v
|
||||
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
|
||||
out[k] = v.reshape(B * S, *v.shape[2:])
|
||||
else:
|
||||
out[k] = v
|
||||
|
||||
if export_feat_layers:
|
||||
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
|
||||
return out
|
||||
|
||||
def _reshape_aux_features(self, aux_feats, H: int, W: int):
|
||||
"""Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C)."""
|
||||
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
|
||||
out = []
|
||||
for f in aux_feats:
|
||||
B, S, N, C = f.shape
|
||||
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
|
||||
out.append(f.reshape(B, S, ph, pw, C))
|
||||
return out
|
||||
128
comfy/ldm/depth_anything_3/preprocess.py
Normal file
128
comfy/ldm/depth_anything_3/preprocess.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Input/output preprocessing helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
# ImageNet normalization constants used during DA3 training.
|
||||
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
|
||||
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
|
||||
down = (x // patch) * patch
|
||||
up = down + patch
|
||||
return up if abs(up - x) <= abs(x - down) else down
|
||||
|
||||
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
"""Compute (target_h, target_w) for a single image.
|
||||
upper_bound_resize: scale longest side to process_res, then round each dim to nearest multiple of 14 (default upstream method).
|
||||
lower_bound_resize: scale shortest side to process_res, then round."""
|
||||
|
||||
if method == "upper_bound_resize":
|
||||
longest = max(orig_h, orig_w)
|
||||
scale = process_res / float(longest)
|
||||
elif method == "lower_bound_resize":
|
||||
shortest = min(orig_h, orig_w)
|
||||
scale = process_res / float(shortest)
|
||||
else:
|
||||
raise ValueError(f"Unsupported process_res_method: {method}")
|
||||
|
||||
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
|
||||
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
|
||||
return new_h, new_w
|
||||
|
||||
|
||||
def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
B, H, W, _ = image.shape
|
||||
target_h, target_w = compute_target_size(H, W, process_res, method)
|
||||
|
||||
# (B, H, W, 3) -> (B, 3, H, W)
|
||||
x = image.movedim(-1, 1).contiguous()
|
||||
if (target_h, target_w) != (H, W):
|
||||
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
|
||||
# Lanczos in ``common_upscale`` is anti-aliased and produces the
|
||||
# closest pixel-wise match in a sweep across {bilinear, bicubic,
|
||||
# area, lanczos, bislerp}. Used in both directions for simplicity.
|
||||
x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",)
|
||||
x = x.clamp(0.0, 1.0)
|
||||
|
||||
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Output post-processing (sky-aware clipping for Mono/Metric variants)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
|
||||
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
|
||||
return sky_prediction < threshold
|
||||
|
||||
|
||||
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
|
||||
"""Clips sky regions to the 99th percentile of non-sky depth. Returns a new depth tensor."""
|
||||
non_sky = compute_non_sky_mask(sky, threshold=threshold)
|
||||
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
|
||||
return depth.clone()
|
||||
|
||||
non_sky_depth = depth[non_sky]
|
||||
if non_sky_depth.numel() > 100_000:
|
||||
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
|
||||
sampled = non_sky_depth[idx]
|
||||
else:
|
||||
sampled = non_sky_depth
|
||||
|
||||
max_depth = torch.quantile(sampled, quantile)
|
||||
out = depth.clone()
|
||||
out[~non_sky] = max_depth
|
||||
return out
|
||||
|
||||
|
||||
def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor:
|
||||
"""V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0)."""
|
||||
if sky is not None:
|
||||
mask = compute_non_sky_mask(sky)
|
||||
if mask.any():
|
||||
valid = depth[mask]
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
|
||||
if valid.numel() > 100_000:
|
||||
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
|
||||
sample = valid[idx]
|
||||
else:
|
||||
sample = valid
|
||||
|
||||
lo = torch.quantile(sample, low_quantile)
|
||||
hi = torch.quantile(sample, high_quantile)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
# Nearer pixels are brighter (1.0)
|
||||
norm = 1.0 - norm
|
||||
if sky is not None:
|
||||
# Sky pixels become black (far / unknown)
|
||||
sky_mask = ~compute_non_sky_mask(sky)
|
||||
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
|
||||
return norm
|
||||
|
||||
|
||||
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Simple per-frame min/max normalization with near=1.0 convention."""
|
||||
lo = depth.amin(dim=(-2, -1), keepdim=True)
|
||||
hi = depth.amax(dim=(-2, -1), keepdim=True)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
272
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
272
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
|
||||
|
||||
|
||||
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Decompose A = Q @ L with Q orthogonal and L lower-triangular.
|
||||
Implemented in terms of QR by reversing the columns/rows; the standard
|
||||
trick from the upstream reference. Inputs A are (3, 3)."""
|
||||
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype)
|
||||
A_tilde = A @ P
|
||||
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
|
||||
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
|
||||
Q_tilde = Q_tilde.to(A.dtype)
|
||||
R_tilde = R_tilde.to(A.dtype)
|
||||
Q = Q_tilde @ P
|
||||
L = P @ R_tilde @ P
|
||||
d = torch.diag(L)
|
||||
sign = torch.sign(d)
|
||||
Q = Q * sign[None, :] # scale columns of Q
|
||||
L = L * sign[:, None] # scale rows of L
|
||||
return Q, L
|
||||
|
||||
|
||||
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Weighted-LSQ + RANSAC homography (batched)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor:
|
||||
"""Solve a single H with weighted least-squares (DLT)."""
|
||||
N = src_pts.shape[0]
|
||||
if N < 4:
|
||||
raise ValueError("At least 4 points are required to compute a homography.")
|
||||
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
|
||||
x = src_pts[:, 0:1]
|
||||
y = src_pts[:, 1:2]
|
||||
u = dst_pts[:, 0:1]
|
||||
v = dst_pts[:, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
|
||||
A = torch.cat([A1, A2], dim=0) # (2N, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[-1].reshape(3, 3)
|
||||
return H / H[-1, -1]
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor:
|
||||
"""Batched DLT solver. Inputs (B, K, 2) / (B, K); output (B, 3, 3)."""
|
||||
B, K, _ = src_pts_batch.shape
|
||||
w = confident_weight_batch.sqrt().unsqueeze(2)
|
||||
x = src_pts_batch[:, :, 0:1]
|
||||
y = src_pts_batch[:, :, 1:2]
|
||||
u = dst_pts_batch[:, :, 0:1]
|
||||
v = dst_pts_batch[:, :, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
|
||||
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[:, -1].reshape(B, 3, 3)
|
||||
return H / H[:, 2:3, 2:3]
|
||||
|
||||
|
||||
def _ransac_find_homography_weighted_batched(
|
||||
src_pts: torch.Tensor, # (B, N, 2)
|
||||
dst_pts: torch.Tensor, # (B, N, 2)
|
||||
confident_weight: torch.Tensor, # (B, N)
|
||||
n_sample: int,
|
||||
n_iter: int = 100,
|
||||
reproj_threshold: float = 3.0,
|
||||
num_sample_for_ransac: int = 8,
|
||||
max_inlier_num: int = 10000,
|
||||
rand_sample_iters_idx: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices."""
|
||||
B, N, _ = src_pts.shape
|
||||
assert N >= 4
|
||||
device = src_pts.device
|
||||
|
||||
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
|
||||
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
|
||||
|
||||
if rand_sample_iters_idx is None:
|
||||
rand_sample_iters_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
|
||||
for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
|
||||
b_idx = (
|
||||
torch.arange(B, device=device)
|
||||
.view(B, 1, 1)
|
||||
.expand(B, n_iter, num_sample_for_ransac)
|
||||
)
|
||||
src_b = src_pts[b_idx, rand_idx]
|
||||
dst_b = dst_pts[b_idx, rand_idx]
|
||||
w_b = confident_weight[b_idx, rand_idx]
|
||||
|
||||
cB, cN = src_b.shape[:2]
|
||||
H_batch = _find_homography_weighted_lsq_batched(
|
||||
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
|
||||
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
|
||||
|
||||
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
|
||||
proj = torch.bmm(
|
||||
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
|
||||
H_batch.reshape(-1, 3, 3).transpose(1, 2),
|
||||
) # (B*n_iter, N, 3)
|
||||
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
|
||||
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
|
||||
inlier_mask = err < reproj_threshold
|
||||
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
|
||||
|
||||
# Refit with the inlier set (per-batch, since the inlier counts vary).
|
||||
H_inlier_list = []
|
||||
for b in range(B):
|
||||
mask = best_inlier_mask[b]
|
||||
in_src = src_pts[b][mask]
|
||||
in_dst = dst_pts[b][mask]
|
||||
in_w = confident_weight[b][mask]
|
||||
if in_src.shape[0] < 4:
|
||||
# Fall back to identity when RANSAC fails to find enough inliers.
|
||||
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
|
||||
continue
|
||||
sorted_w = torch.argsort(in_w, descending=True)
|
||||
if len(sorted_w) > max_inlier_num:
|
||||
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
|
||||
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
|
||||
H_inlier_list.append(
|
||||
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
|
||||
)
|
||||
return torch.stack(H_inlier_list, dim=0)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Camera-ray utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor:
|
||||
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane."""
|
||||
dx = 1.0 / num_x
|
||||
dy = 1.0 / num_y
|
||||
# Centered camera-space coords directly (skip the K^-1 step since it's
|
||||
# just a translation by -1 on x and y when K is identity-with-center=1).
|
||||
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
|
||||
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
|
||||
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
||||
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
|
||||
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
|
||||
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
|
||||
|
||||
|
||||
def _camray_to_caminfo(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
|
||||
reproj_threshold: float = 0.2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
|
||||
if confidence is None:
|
||||
confidence = torch.ones_like(camray[..., 0])
|
||||
B, S, h, w, _ = camray.shape
|
||||
device = camray.device
|
||||
dtype = camray.dtype
|
||||
|
||||
rays_target = camray[..., :3] # (B, S, h, w, 3)
|
||||
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
|
||||
|
||||
# Flatten (B*S, h*w, *) for the RANSAC routine.
|
||||
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
|
||||
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
|
||||
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
|
||||
|
||||
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
|
||||
z_thresh = 1e-4
|
||||
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
|
||||
weights = torch.where(mask, weights, torch.zeros_like(weights))
|
||||
src = rays_origin.clone()
|
||||
dst = rays_target.clone()
|
||||
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
|
||||
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
|
||||
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
|
||||
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
|
||||
src = src[..., :2]
|
||||
dst = dst[..., :2]
|
||||
|
||||
N = src.shape[1]
|
||||
n_iter = 100
|
||||
sample_ratio = 0.3
|
||||
num_sample_for_ransac = 8
|
||||
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
|
||||
rand_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Chunk along the view axis to keep peak memory predictable.
|
||||
chunk = 2
|
||||
A_list = []
|
||||
for i in range(0, src.shape[0], chunk):
|
||||
A = _ransac_find_homography_weighted_batched(
|
||||
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
|
||||
n_sample=n_sample, n_iter=n_iter,
|
||||
num_sample_for_ransac=num_sample_for_ransac,
|
||||
reproj_threshold=reproj_threshold,
|
||||
rand_sample_iters_idx=rand_idx,
|
||||
max_inlier_num=8000,
|
||||
)
|
||||
# Flip sign on dets that come out < 0 (so that the QL produces a
|
||||
# right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
|
||||
# do the comparison in fp32.
|
||||
flip = torch.linalg.det(A.float()) < 0
|
||||
A = torch.where(flip[:, None, None], -A, A)
|
||||
A_list.append(A)
|
||||
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
|
||||
|
||||
R_list, f_list, pp_list = [], [], []
|
||||
for i in range(A.shape[0]):
|
||||
R, L = _ql_decomposition(A[i])
|
||||
L = L / L[2][2]
|
||||
f_list.append(torch.stack((L[0][0], L[1][1])))
|
||||
pp_list.append(torch.stack((L[2][0], L[2][1])))
|
||||
R_list.append(R)
|
||||
R = torch.stack(R_list).reshape(B, S, 3, 3)
|
||||
focal = torch.stack(f_list).reshape(B, S, 2)
|
||||
pp = torch.stack(pp_list).reshape(B, S, 2)
|
||||
|
||||
# Translation: confidence-weighted average of camray direction(s).
|
||||
cf = confidence.flatten(0, 1).flatten(1, 2)
|
||||
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
|
||||
T = T / cf.sum(dim=-1, keepdim=True)
|
||||
T = T.reshape(B, S, 3)
|
||||
|
||||
# Match upstream output convention: focal -> 1/focal, pp + 1.
|
||||
return R, T, 1.0 / focal, pp + 1.0
|
||||
|
||||
|
||||
def get_extrinsic_from_camray(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
|
||||
patch_size_y: int,
|
||||
patch_size_x: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output."""
|
||||
if conf.ndim == 5 and conf.shape[-1] == 1:
|
||||
conf = conf.squeeze(-1)
|
||||
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
|
||||
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
|
||||
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
|
||||
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
|
||||
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
|
||||
return extr, focal, pp
|
||||
87
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
87
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Reference-view selection for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
|
||||
|
||||
|
||||
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
|
||||
# Reference selection only runs when there are at least this many views.
|
||||
THRESH_FOR_REF_SELECTION: int = 3
|
||||
|
||||
|
||||
def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor:
|
||||
"""Pick a reference view index per batch element."""
|
||||
B, S, _, _ = x.shape
|
||||
if S <= 1:
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "first":
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "middle":
|
||||
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
|
||||
|
||||
# Feature-based strategies: normalised cls/cam token per view.
|
||||
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
|
||||
|
||||
if strategy == "saddle_balanced":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
|
||||
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
|
||||
feat_var = img_class_feat.var(dim=-1) # (B,S)
|
||||
|
||||
def _normalize(metric):
|
||||
mn = metric.min(dim=1, keepdim=True).values
|
||||
mx = metric.max(dim=1, keepdim=True).values
|
||||
return (metric - mn) / (mx - mn + 1e-8)
|
||||
|
||||
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
|
||||
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
|
||||
return balance.argmin(dim=1)
|
||||
|
||||
if strategy == "saddle_sim_range":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_max = sim_no_diag.max(dim=-1).values
|
||||
sim_min = sim_no_diag.min(dim=-1).values
|
||||
return (sim_max - sim_min).argmax(dim=1)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reference view selection strategy: {strategy!r}. "
|
||||
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
|
||||
)
|
||||
|
||||
|
||||
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Reorder x so the reference view is at position 0 in axis S."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
reorder = torch.where(
|
||||
(positions > 0) & (positions <= b_idx_exp),
|
||||
positions - 1,
|
||||
positions,
|
||||
)
|
||||
reorder[:, 0] = b_idx
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, reorder]
|
||||
|
||||
|
||||
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of reorder_by_reference."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions)
|
||||
restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp))
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, restore]
|
||||
160
comfy/ldm/depth_anything_3/transform.py
Normal file
160
comfy/ldm/depth_anything_3/transform.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""Geometry / camera transform helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Affine 4x4 helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
|
||||
"""Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. No-op when the input is already ``(...,4,4)``."""
|
||||
if ext.shape[-2:] == (4, 4):
|
||||
return ext
|
||||
if ext.shape[-2:] == (3, 4):
|
||||
ones = torch.zeros_like(ext[..., :1, :4])
|
||||
ones[..., 0, 3] = 1.0
|
||||
return torch.cat([ext, ones], dim=-2)
|
||||
raise ValueError(f"Invalid affine shape: {ext.shape}")
|
||||
|
||||
|
||||
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
|
||||
R = A[..., :3, :3]
|
||||
T = A[..., :3, 3:]
|
||||
P = A[..., 3:, :]
|
||||
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quaternion <-> rotation matrix (xyzw / scalar-last)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""sqrt(max(0, x)) with a zero subgradient where x == 0."""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
|
||||
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
||||
|
||||
|
||||
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert quaternions (xyzw) to (...,3,3) rotation matrices."""
|
||||
i, j, k, r = torch.unbind(quaternions, -1)
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
o = torch.stack(
|
||||
(
|
||||
1 - two_s * (j * j + k * k),
|
||||
two_s * (i * j - k * r),
|
||||
two_s * (i * k + j * r),
|
||||
two_s * (i * j + k * r),
|
||||
1 - two_s * (i * i + k * k),
|
||||
two_s * (j * k - i * r),
|
||||
two_s * (i * k - j * r),
|
||||
two_s * (j * k + i * r),
|
||||
1 - two_s * (i * i + j * j),
|
||||
),
|
||||
-1,
|
||||
)
|
||||
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert (...,3,3) rotation matrices to quaternions (xyzw)."""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
batch_dim = matrix.shape[:-2]
|
||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||
matrix.reshape(batch_dim + (9,)), dim=-1
|
||||
)
|
||||
|
||||
q_abs = _sqrt_positive_part(
|
||||
torch.stack(
|
||||
[
|
||||
1.0 + m00 + m11 + m22,
|
||||
1.0 + m00 - m11 - m22,
|
||||
1.0 - m00 + m11 - m22,
|
||||
1.0 - m00 - m11 + m22,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||
|
||||
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
|
||||
batch_dim + (4,)
|
||||
)
|
||||
# Reorder rijk -> xyzw (i.e. ijkr).
|
||||
out = out[..., [1, 2, 3, 0]]
|
||||
return standardize_quaternion(out)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Pose-encoding <-> extrinsics + intrinsics
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Pack (extr, intr, image_size) into the 9-D pose-encoding vector.
|
||||
extrinsics: camera-to-world (c2w) (B,S,4,4) matrices,
|
||||
intrinsics: pixel-space (B,S,3,3) matrices,
|
||||
image_size_hw: is a (H, W) pair.
|
||||
"""
|
||||
R = extrinsics[..., :3, :3]
|
||||
T = extrinsics[..., :3, 3]
|
||||
quat = mat_to_quat(R)
|
||||
H, W = image_size_hw
|
||||
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
||||
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
||||
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
||||
|
||||
|
||||
def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Inverse of extri_intri_to_pose_encoding."""
|
||||
T = pose_encoding[..., :3]
|
||||
quat = pose_encoding[..., 3:7]
|
||||
fov_h = pose_encoding[..., 7]
|
||||
fov_w = pose_encoding[..., 8]
|
||||
# Normalize to unit quaternion. CameraDec outputs raw values; a near-zero
|
||||
# quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics.
|
||||
quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||
R = quat_to_mat(quat)
|
||||
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
||||
H, W = image_size_hw
|
||||
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
|
||||
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
|
||||
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
|
||||
intrinsics[..., 0, 0] = fx
|
||||
intrinsics[..., 1, 1] = fy
|
||||
intrinsics[..., 0, 2] = W / 2
|
||||
intrinsics[..., 1, 2] = H / 2
|
||||
intrinsics[..., 2, 2] = 1.0
|
||||
return extrinsics, intrinsics
|
||||
@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
|
||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, dtype):
|
||||
x = x.to(torch.float32)
|
||||
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||
emb = emb.to(self.mlp_in.weight.dtype)
|
||||
emb = emb.to(dtype)
|
||||
emb = F.silu(self.mlp_in(emb))
|
||||
return self.mlp_out(emb)
|
||||
|
||||
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
x = x * output_image_mask
|
||||
h = self.input_proj(x) * output_image_mask
|
||||
|
||||
t_cond = self.t_embedding(t)
|
||||
t_cond = self.t_embedding(t, dtype=x.dtype)
|
||||
if t.dim() == 1:
|
||||
t_cond = t_cond.unsqueeze(1)
|
||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||
@ -174,7 +174,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
llm = self.llm_cond_proj(llm) * text_mask
|
||||
h[:, :L_text] = h[:, :L_text] + llm
|
||||
|
||||
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long))
|
||||
h = h + self.embed_image_indicator((indicator == OUTPUT_IMAGE_INDICATOR).to(torch.long), out_dtype=h.dtype)
|
||||
|
||||
# Qwen3-VL interleaved MRoPE; position_ids (B, L, 3) -> (3, L) (same across batch).
|
||||
freqs_cis = precompute_freqs_cis(
|
||||
@ -235,7 +235,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
||||
def _run_conditional(self, x_chunk, context_chunk, attn_mask_chunk, t_chunk, gh, gw, transformer_options):
|
||||
B = x_chunk.shape[0]
|
||||
device = x_chunk.device
|
||||
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
|
||||
img_tokens = self._img_to_tokens(x_chunk)
|
||||
L_img = img_tokens.shape[1]
|
||||
L_text = context_chunk.shape[1]
|
||||
L = L_text + L_img
|
||||
@ -268,7 +268,7 @@ class Ideogram4Transformer2DModel(Ideogram4Transformer):
|
||||
def _run_image_only(self, x_chunk, t_chunk, gh, gw, transformer_options):
|
||||
B = x_chunk.shape[0]
|
||||
device = x_chunk.device
|
||||
img_tokens = self._img_to_tokens(x_chunk).to(self.dtype)
|
||||
img_tokens = self._img_to_tokens(x_chunk)
|
||||
L_img = img_tokens.shape[1]
|
||||
|
||||
position_ids = self._image_position_ids(gh, gw, device).unsqueeze(0).expand(B, L_img, 3)
|
||||
|
||||
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
return apply_rope1(x, freqs_cis)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -51,6 +51,18 @@ class FeedForward(nn.Module):
|
||||
return hidden_states
|
||||
|
||||
|
||||
# Addin this back because Nunchaku custom nodes rely on it, see comment here:
|
||||
# https://github.com/Comfy-Org/ComfyUI/pull/14178#issuecomment-4640475161
|
||||
# TODO: Eventually remove this once we natively support SVDQuants
|
||||
def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape)
|
||||
|
||||
|
||||
class QwenTimestepProjEmbeddings(nn.Module):
|
||||
def __init__(self, embedding_dim, pooled_projection_dim, use_additional_t_cond=False, dtype=None, device=None, operations=None):
|
||||
super().__init__()
|
||||
|
||||
@ -8,7 +8,7 @@ from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.math import apply_rope1, rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -570,6 +570,14 @@ class WanModel(torch.nn.Module):
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# In-context reference (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
main_len = x.shape[1]
|
||||
if context_latents is not None:
|
||||
for lat in context_latents:
|
||||
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
|
||||
x = torch.cat([x, cl], dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@ -599,6 +607,9 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if context_latents is not None:
|
||||
x = x[:, :main_len]
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
@ -606,7 +617,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@ -638,6 +649,13 @@ class WanModel(torch.nn.Module):
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
|
||||
# In-context reference: a non-zero source_id composes an extra rotation into the spatial rope
|
||||
if source_id:
|
||||
d = self.dim // self.num_heads
|
||||
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
|
||||
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
|
||||
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
@ -661,6 +679,15 @@ class WanModel(torch.nn.Module):
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
# In-context reference: one rope block per stream, each with it's own source_id (1, 2, ...) to distinguish from the target (id 0).
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
|
||||
for i, lat in enumerate(context_latents):
|
||||
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
|
||||
kwargs = {**kwargs, "context_latents": context_latents}
|
||||
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@ -1631,13 +1658,15 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
|
||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -1645,6 +1674,8 @@ class SCAILWanModel(WanModel):
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
if sam_latents is not None: # SCAIL-2 additive mask stream
|
||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
@ -1695,7 +1726,36 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
|
||||
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||
if ref_mask_flag is not None and not bool(ref_mask_flag):
|
||||
REF_ROPE_H = 120.0
|
||||
POSE_ROPE_W = 120.0
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
main_t_patches = t - ref_t_patches
|
||||
|
||||
parts = []
|
||||
if ref_t_patches > 0:
|
||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||
if main_t_patches > 0:
|
||||
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||
|
||||
if pose_latents is not None:
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||
|
||||
return torch.cat(parts, dim=1)
|
||||
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
@ -1719,12 +1779,16 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
if ref_mask_latents is not None: # SCAIL-2
|
||||
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
||||
if sam_latents is not None: # SCAIL-2
|
||||
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
@ -1737,5 +1801,15 @@ class SCAILWanModel(WanModel):
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class SCAIL2WanModel(SCAILWanModel):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -65,6 +65,7 @@ import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
import comfy.ldm.hidream_o1.model
|
||||
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
|
||||
import comfy.ldm.depth_anything_3.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1518,8 +1519,26 @@ class WAN21(BaseModel):
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
# In-context reference conditioning (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
# In-context cond slicing (Bernini)
|
||||
if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list):
|
||||
dim = window.dim
|
||||
out = []
|
||||
for lat in cond_value.cond:
|
||||
if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]:
|
||||
out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list))
|
||||
else:
|
||||
out.append(lat.to(device))
|
||||
return cond_value._copy_with(out)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@ -1754,6 +1773,97 @@ class WAN21_SCAIL(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
class WAN21_SCAIL2(WAN21_SCAIL):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
|
||||
self.memory_usage_shape_process = {
|
||||
"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
"sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
}
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_flag = kwargs.get("ref_mask_flag", None)
|
||||
if ref_mask_flag is not None:
|
||||
out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = super().extra_conds_shapes(**kwargs)
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
s = driving_mask_28ch.shape
|
||||
out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
s = ref_mask_28ch.shape
|
||||
out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key in ("sam_latents", "pose_latents"):
|
||||
# Return sliced view omitting retain_index_list
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)
|
||||
if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
# The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows
|
||||
full_ref_mask = cond_value.cond
|
||||
video_frame_count = x_in.shape[2]
|
||||
if full_ref_mask.shape[2] != video_frame_count + 1:
|
||||
return None
|
||||
window_length = len(window.index_list)
|
||||
|
||||
# Account for the causal anchor frame if it exists
|
||||
anchor_index = getattr(window, "causal_anchor_index", None)
|
||||
if anchor_index is not None and anchor_index >= 0:
|
||||
window_length += 1
|
||||
|
||||
window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device)
|
||||
return cond_value._copy_with(window_ref_mask)
|
||||
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
# The 4 extra channels are the history_mask (1 at clean-anchor frames).
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
|
||||
if extra_channels != 4:
|
||||
return super().concat_cond(**kwargs)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
return torch.zeros_like(noise)[:, :4]
|
||||
|
||||
device = kwargs["device"]
|
||||
if mask.shape[1] != 4:
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return mask
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
# Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
|
||||
return latent_image
|
||||
|
||||
|
||||
class WAN22_WanDancer(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
|
||||
@ -2227,6 +2337,12 @@ class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
|
||||
class DepthAnything3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
|
||||
|
||||
class ErnieImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||
|
||||
@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail2"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
|
||||
@ -860,6 +862,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
|
||||
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "DepthAnything3"
|
||||
|
||||
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
|
||||
embed_dim = patch_w.shape[0]
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
|
||||
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
|
||||
if backbone_name is None:
|
||||
return None
|
||||
dit_config["backbone_name"] = backbone_name
|
||||
|
||||
# Detect DA3 extensions on top of vanilla DINOv2.
|
||||
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
|
||||
qknorm_indices = [
|
||||
i for i in range(depth)
|
||||
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
]
|
||||
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
|
||||
|
||||
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
|
||||
# cat_token=True is implied by the presence of camera_token.
|
||||
if has_camera_token:
|
||||
dit_config["alt_start"] = qknorm_start
|
||||
dit_config["rope_start"] = qknorm_start
|
||||
dit_config["qknorm_start"] = qknorm_start
|
||||
dit_config["cat_token"] = True
|
||||
else:
|
||||
dit_config["alt_start"] = -1
|
||||
dit_config["rope_start"] = -1
|
||||
dit_config["qknorm_start"] = -1
|
||||
dit_config["cat_token"] = False
|
||||
|
||||
# Detect head type and config.
|
||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_out_channels"] = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
if has_aux:
|
||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
dit_config["head_output_dim"] = 2
|
||||
dit_config["head_use_sky_head"] = False
|
||||
else:
|
||||
dit_config["head_type"] = "dpt"
|
||||
dit_config["head_output_dim"] = state_dict[
|
||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||
].shape[0]
|
||||
dit_config["head_use_sky_head"] = (
|
||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
# out_layers: hard-coded per upstream YAML config (depth-aware default).
|
||||
if depth >= 24:
|
||||
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
|
||||
if has_aux:
|
||||
dit_config["out_layers"] = [11, 15, 19, 23]
|
||||
else:
|
||||
dit_config["out_layers"] = [4, 11, 17, 23]
|
||||
else:
|
||||
# vits/vitb: 12 blocks
|
||||
dit_config["out_layers"] = [5, 7, 9, 11]
|
||||
|
||||
# Camera encoder/decoder presence (multi-view + pose path).
|
||||
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
|
||||
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["has_cam_enc"] = has_cam_enc
|
||||
dit_config["has_cam_dec"] = has_cam_dec
|
||||
if has_cam_enc:
|
||||
cam_enc_w = state_dict.get(
|
||||
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_enc_w is not None:
|
||||
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
|
||||
if has_cam_dec:
|
||||
cam_dec_w = state_dict.get(
|
||||
'{}cam_dec.fc_t.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_dec_w is not None:
|
||||
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
|
||||
return dit_config
|
||||
|
||||
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ernie"
|
||||
|
||||
@ -534,8 +534,10 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def set_cudnn_benchmark():
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available():
|
||||
torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
@ -641,6 +643,8 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if args.high_ram:
|
||||
return True
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
@ -651,8 +655,7 @@ def ensure_pin_budget(size, evict_active=False):
|
||||
to_free = shortfall + PIN_PRESSURE_HYSTERESIS
|
||||
return free_pins(to_free, evict_active=evict_active) >= shortfall
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
def free_registrations(shortfall, evict_active=True):
|
||||
if MAX_PINNED_MEMORY <= 0:
|
||||
return False
|
||||
if shortfall <= 0:
|
||||
@ -674,6 +677,9 @@ def ensure_pin_registerable(size, evict_active=True):
|
||||
return True
|
||||
return shortfall <= REGISTERABLE_PIN_HYSTERESIS
|
||||
|
||||
def ensure_pin_registerable(size, evict_active=True):
|
||||
return free_registrations(TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY, evict_active=evict_active)
|
||||
|
||||
class LoadedModel:
|
||||
def __init__(self, model: ModelPatcher):
|
||||
self._set_model(model)
|
||||
@ -956,8 +962,6 @@ def loaded_models(only_currently_used=False):
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
@ -1494,6 +1498,8 @@ if not args.disable_pinned_memory:
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
|
||||
if args.high_ram:
|
||||
return max(0, int(size * 2))
|
||||
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
|
||||
@ -379,10 +379,11 @@ class ModelPatcher:
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
|
||||
class_ = self.__class__
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if self.is_dynamic() and disable_dynamic or force_deepcopy:
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if model_override is None:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
|
||||
18
comfy/ops.py
18
comfy/ops.py
@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
if signature is None or args.high_ram:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
@ -299,21 +299,21 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
if hasattr(s, "_v") and comfy.model_management.is_device_cpu(device):
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
elif hasattr(s, "_v") and s.weight.device != device:
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
|
||||
@ -89,13 +89,26 @@ def pin_memory(module, subset="weights", size=None):
|
||||
not comfy.model_management.ensure_pin_registerable(registerable_size)):
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
extended = False
|
||||
try:
|
||||
hostbuf.extend(size=size)
|
||||
hostbuf.extend(size=size, register=False)
|
||||
extended = True
|
||||
pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
comfy.model_management.free_registrations(size)
|
||||
if torch.cuda.cudart().cudaHostRegister(pin.data_ptr(), size, 1) != 0:
|
||||
comfy.model_management.discard_cuda_async_error()
|
||||
del pin
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
except RuntimeError:
|
||||
if extended:
|
||||
hostbuf.truncate(offset, do_unregister=False)
|
||||
return _steal_pin(module, stack, buckets, size, priority)
|
||||
|
||||
module._pin = comfy_aimdo.torch.hostbuf_to_tensor(hostbuf)[offset:offset + size]
|
||||
module._pin.untyped_storage()._comfy_hostbuf = hostbuf
|
||||
module._pin = pin
|
||||
stack.append((module, offset))
|
||||
module._pin_registered = True
|
||||
module._pin_stack_index = len(stack) - 1
|
||||
|
||||
@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_SCAIL2(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "scail2",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -2045,6 +2056,23 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
return None
|
||||
|
||||
|
||||
class DepthAnything3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "DepthAnything3",
|
||||
}
|
||||
|
||||
# Mono path: no num_heads / num_head_channels needed.
|
||||
unet_extra_config = {}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.DepthAnything3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ernie",
|
||||
@ -2259,6 +2287,7 @@ models = [
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
WAN21_SCAIL2,
|
||||
WAN22_WanDancer,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
@ -2286,4 +2315,5 @@ models = [
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
DepthAnything3,
|
||||
]
|
||||
|
||||
@ -32,7 +32,9 @@ class Ideogram4Tokenizer(sd1_clip.SD1Tokenizer):
|
||||
self.llama_template = "<|im_start|>user\n{}<|im_end|>\n<|im_start|>assistant\n"
|
||||
|
||||
def tokenize_with_weights(self, text, return_word_ids=False, llama_template=None, **kwargs):
|
||||
if llama_template is None:
|
||||
if text.startswith('<|im_start|>'):
|
||||
llama_text = text
|
||||
elif llama_template is None:
|
||||
llama_text = self.llama_template.format(text)
|
||||
else:
|
||||
llama_text = llama_template.format(text)
|
||||
|
||||
@ -27,10 +27,13 @@ class VideoInput(ABC):
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
|
||||
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -83,6 +86,14 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
"""
|
||||
Returns the bit depth of the video (e.g. 8 or 10).
|
||||
|
||||
Default implementation returns 8; subclasses report their real depth.
|
||||
"""
|
||||
return 8
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -52,6 +52,12 @@ def get_open_write_kwargs(
|
||||
return open_kwargs
|
||||
|
||||
|
||||
def video_stream_bit_depth(stream) -> int:
|
||||
if stream is None or stream.format is None or not stream.format.components:
|
||||
return 8
|
||||
return max(component.bits for component in stream.format.components)
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
return video_stream_bit_depth(video_stream)
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
@ -257,6 +270,7 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
process_image_format = lambda a: a
|
||||
align_graph = None
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
@ -310,7 +324,24 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
checked_alpha = True
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
# Fix non-deterministic video decode when the video width is not a multiple of 32
|
||||
# For non-yuvj pixel formats (all H.264/H.265 video)
|
||||
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
|
||||
if align_graph is None:
|
||||
pad_w = ((frame.width + 31) // 32) * 32
|
||||
g = av.filter.Graph()
|
||||
g_src = g.add_buffer(width=frame.width, height=frame.height,
|
||||
format=frame.format.name, time_base=video_stream.time_base)
|
||||
g_pad = g.add('pad', f'{pad_w}:{frame.height}:0:0')
|
||||
g_sink = g.add('buffersink')
|
||||
g_src.link_to(g_pad)
|
||||
g_pad.link_to(g_sink)
|
||||
g.configure()
|
||||
align_graph = (g, g_src, g_sink)
|
||||
align_graph[1].push(frame)
|
||||
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:, :frame.width])
|
||||
else:
|
||||
img = frame.to_ndarray(format=image_format)
|
||||
if frame.rotation != 0:
|
||||
k = int(round(frame.rotation // 90))
|
||||
img = np.rot90(img, k=k, axes=(0, 1)).copy()
|
||||
@ -377,25 +408,32 @@ class VideoFromFile(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
if bit_depth is None:
|
||||
bit_depth = source_bit_depth
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -451,8 +489,10 @@ class VideoFromComponents(VideoInput):
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
def __init__(self, components: VideoComponents, bit_depth: int = 8):
|
||||
self.__components = components
|
||||
# Tensor components have no inherent bit depth; this is the depth used when encoding.
|
||||
self.__bit_depth = bit_depth
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
@ -461,18 +501,26 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
return self.__bit_depth
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||
if bit_depth is None:
|
||||
bit_depth = self.__bit_depth
|
||||
is_10bit = bit_depth >= 10
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
@ -488,10 +536,11 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
video_stream.pix_fmt = pix_fmt
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
@ -505,9 +554,14 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
if is_10bit:
|
||||
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||
else:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format=pix_fmt)
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
@ -755,6 +755,18 @@ class File3DKSPLAT(ComfyTypeIO):
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_SPLAT_ANY")
|
||||
class File3DSplatAny(ComfyTypeIO):
|
||||
"""General 3D Gaussian splat file type - accepts any supported splat container (.ply / .spz / .splat / .ksplat)."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="FILE_3D_POINT_CLOUD_ANY")
|
||||
class File3DPointCloudAny(ComfyTypeIO):
|
||||
"""General point cloud file type - accepts any supported point cloud container (currently .ply)."""
|
||||
Type = File3D
|
||||
|
||||
|
||||
@comfytype(io_type="HOOKS")
|
||||
class Hooks(ComfyTypeIO):
|
||||
if TYPE_CHECKING:
|
||||
@ -1388,7 +1400,8 @@ class V3Data(TypedDict):
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
extra_pnginfo: Any, dynprompt: Any,
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str,
|
||||
comfy_usage_source: str = None, **kwargs):
|
||||
self.unique_id = unique_id
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
self.prompt = prompt
|
||||
@ -1401,6 +1414,8 @@ class HiddenHolder:
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
self.api_key_comfy_org = api_key_comfy_org
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
self.comfy_usage_source = comfy_usage_source
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
'''If hidden variable not found, return None.'''
|
||||
@ -1417,6 +1432,7 @@ class HiddenHolder:
|
||||
dynprompt=d.get(Hidden.dynprompt, None),
|
||||
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
|
||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1439,6 +1455,8 @@ class Hidden(str, Enum):
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
comfy_usage_source = "COMFY_USAGE_SOURCE"
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1642,6 +1660,8 @@ class Schema:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
if Hidden.comfy_usage_source not in self.hidden:
|
||||
self.hidden.append(Hidden.comfy_usage_source)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if Hidden.prompt not in self.hidden:
|
||||
@ -2336,6 +2356,8 @@ __all__ = [
|
||||
"File3DSPLAT",
|
||||
"File3DSPZ",
|
||||
"File3DKSPLAT",
|
||||
"File3DSplatAny",
|
||||
"File3DPointCloudAny",
|
||||
"Hooks",
|
||||
"HookKeyframes",
|
||||
"TimestepsRange",
|
||||
|
||||
@ -285,7 +285,7 @@ class AudioSaveHelper:
|
||||
results = []
|
||||
for batch_number, waveform in enumerate(audio["waveform"].cpu()):
|
||||
filename_with_batch_num = filename.replace("%batch_num%", str(batch_number))
|
||||
file = f"{filename_with_batch_num}_{counter:05}_.{format}"
|
||||
file = f"{filename_with_batch_num}_{counter:05}.{format}"
|
||||
output_path = os.path.join(full_output_folder, file)
|
||||
|
||||
# Use original sample rate initially
|
||||
|
||||
9
comfy_api_nodes/apis/__init__.py
generated
9
comfy_api_nodes/apis/__init__.py
generated
@ -1310,13 +1310,6 @@ class KlingTaskStatus(str, Enum):
|
||||
failed = 'failed'
|
||||
|
||||
|
||||
class KlingTextToVideoModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
kling_v2_5_turbo = 'kling-v2-5-turbo'
|
||||
|
||||
|
||||
class KlingVideoGenAspectRatio(str, Enum):
|
||||
field_16_9 = '16:9'
|
||||
field_9_16 = '9:16'
|
||||
@ -5179,7 +5172,7 @@ class KlingText2VideoRequest(BaseModel):
|
||||
duration: Optional[KlingVideoGenDuration] = '5'
|
||||
external_task_id: Optional[str] = Field(None, description='Customized Task ID')
|
||||
mode: Optional[KlingVideoGenMode] = 'std'
|
||||
model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
|
||||
model_name: Optional[str] = 'kling-v1'
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative text prompt', max_length=2500
|
||||
)
|
||||
|
||||
@ -43,6 +43,7 @@ class BFLFluxEraseRequest(BaseModel):
|
||||
"white (255) marks areas to remove, black (0) marks areas to preserve.",
|
||||
)
|
||||
dilate_pixels: int = Field(10)
|
||||
seed: int | None = Field(None)
|
||||
output_format: str = Field("png")
|
||||
|
||||
|
||||
|
||||
@ -97,3 +97,28 @@ class BriaRemoveVideoBackgroundResult(BaseModel):
|
||||
class BriaRemoveVideoBackgroundResponse(BaseModel):
|
||||
status: str = Field(...)
|
||||
result: BriaRemoveVideoBackgroundResult | None = Field(None)
|
||||
|
||||
|
||||
class BriaVideoGreenScreenRequest(BaseModel):
|
||||
video: str = Field(..., description="Publicly accessible URL of the input video.")
|
||||
green_shade: str = Field(
|
||||
default="broadcast_green",
|
||||
description="Solid chroma-key shade applied behind the foreground "
|
||||
"(broadcast_green, chroma_green, or blue_screen).",
|
||||
)
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
|
||||
class BriaVideoReplaceBackgroundRequest(BaseModel):
|
||||
video: str = Field(..., description="Publicly accessible URL of the input (foreground) video.")
|
||||
background_url: str = Field(
|
||||
...,
|
||||
description="Publicly accessible URL of the background image or video to composite behind "
|
||||
"the foreground. Stretched to the foreground frame; match its aspect ratio for "
|
||||
"undistorted results.",
|
||||
)
|
||||
output_container_and_codec: str = Field(...)
|
||||
preserve_audio: bool = Field(True)
|
||||
seed: int = Field(...)
|
||||
|
||||
@ -108,13 +108,19 @@ class GeminiVideoMetadata(BaseModel):
|
||||
startOffset: GeminiOffset | None = Field(None)
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
includeThoughts: bool | None = Field(None)
|
||||
thinkingLevel: str = Field(...)
|
||||
|
||||
|
||||
class GeminiGenerationConfig(BaseModel):
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=8192)
|
||||
maxOutputTokens: int | None = Field(None, ge=16, le=65536)
|
||||
seed: int | None = Field(None)
|
||||
stopSequences: list[str] | None = Field(None)
|
||||
temperature: float | None = Field(None, ge=0.0, le=2.0)
|
||||
topK: int | None = Field(None, ge=1)
|
||||
topP: float | None = Field(None, ge=0.0, le=1.0)
|
||||
thinkingConfig: GeminiThinkingConfig | None = Field(None)
|
||||
|
||||
|
||||
class GeminiImageOutputOptions(BaseModel):
|
||||
@ -128,11 +134,6 @@ class GeminiImageConfig(BaseModel):
|
||||
imageOutputOptions: GeminiImageOutputOptions = Field(default_factory=GeminiImageOutputOptions)
|
||||
|
||||
|
||||
class GeminiThinkingConfig(BaseModel):
|
||||
includeThoughts: bool | None = Field(None)
|
||||
thinkingLevel: str = Field(...)
|
||||
|
||||
|
||||
class GeminiImageGenerationConfig(GeminiGenerationConfig):
|
||||
responseModalities: list[str] | None = Field(None)
|
||||
imageConfig: GeminiImageConfig | None = Field(None)
|
||||
|
||||
@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayTaskStatusEnum(str, Enum):
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
CANCELLED = 'CANCELLED'
|
||||
THROTTLED = 'THROTTLED'
|
||||
|
||||
|
||||
class RunwayTaskStatusResponse(BaseModel):
|
||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||
id: str = Field(..., description='Task ID')
|
||||
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
status: RunwayTaskStatusEnum
|
||||
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||
|
||||
|
||||
class Model4(str, Enum):
|
||||
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
|
||||
|
||||
class RunwayTextToImageResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayAleph2IO:
|
||||
"""Custom socket types for chaining Aleph2 guidance images."""
|
||||
|
||||
KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
|
||||
PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||
|
||||
|
||||
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
|
||||
# choose the request model below. The values match the Aleph2 keyframe union field names.
|
||||
KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video
|
||||
KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration
|
||||
|
||||
# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video
|
||||
PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeItem:
|
||||
"""A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeChain:
|
||||
"""An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2KeyframeItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2KeyframeItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2KeyframeChain":
|
||||
c = RunwayAleph2KeyframeChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageItem:
|
||||
"""A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageChain:
|
||||
"""An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2PromptImageItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2PromptImageItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2PromptImageChain":
|
||||
c = RunwayAleph2PromptImageChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeSeconds(BaseModel):
|
||||
seconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
|
||||
ge=0.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeAt(BaseModel):
|
||||
at: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the input video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2TimestampPosition(BaseModel):
|
||||
type: str = Field(default="timestamp")
|
||||
timestampSeconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the output video.",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2RelativePosition(BaseModel):
|
||||
type: str = Field(default="position")
|
||||
positionPercentage: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the total output video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImage(BaseModel):
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2ContentModeration(BaseModel):
|
||||
publicFigureThreshold: str = Field(
|
||||
...,
|
||||
description='When set to "low", the content moderation system is less strict about '
|
||||
'recognizable public figures. One of "auto" or "low".',
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Request(BaseModel):
|
||||
model: str = Field(default="aleph2")
|
||||
promptText: str = Field(
|
||||
...,
|
||||
description="A non-empty string describing what should appear in the output.",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
videoUri: str = Field(...)
|
||||
seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295)
|
||||
contentModeration: RunwayAleph2ContentModeration = Field(...)
|
||||
keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
|
||||
None,
|
||||
description="Timed guidance images placed at specific points in the input video. Up to 5.",
|
||||
)
|
||||
promptImage: list[RunwayAleph2PromptImage] | None = Field(
|
||||
None,
|
||||
description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Response(BaseModel):
|
||||
id: str | None = Field(None, description="Task ID")
|
||||
|
||||
@ -534,6 +534,15 @@ class FluxEraseNode(IO.ComfyNode):
|
||||
max=25,
|
||||
tooltip="Expands the mask boundaries to ensure clean coverage of the object's edges.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="The random seed used for creating the noise.",
|
||||
optional=True,
|
||||
),
|
||||
],
|
||||
outputs=[IO.Image.Output()],
|
||||
hidden=[
|
||||
@ -553,6 +562,7 @@ class FluxEraseNode(IO.ComfyNode):
|
||||
image: Input.Image,
|
||||
mask: Input.Image,
|
||||
dilate_pixels: int = 10,
|
||||
seed: int = 0,
|
||||
) -> IO.NodeOutput:
|
||||
validate_image_dimensions(image, min_width=256, min_height=256)
|
||||
mask = resize_mask_to_image(mask, image)
|
||||
@ -565,6 +575,7 @@ class FluxEraseNode(IO.ComfyNode):
|
||||
image=tensor_to_base64_string(image[:, :, :, :3]), # make sure image will have alpha channel removed
|
||||
mask=mask,
|
||||
dilate_pixels=dilate_pixels,
|
||||
seed=seed,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@ -12,6 +12,8 @@ from comfy_api_nodes.apis.bria import (
|
||||
BriaRemoveVideoBackgroundRequest,
|
||||
BriaRemoveVideoBackgroundResponse,
|
||||
BriaStatusResponse,
|
||||
BriaVideoGreenScreenRequest,
|
||||
BriaVideoReplaceBackgroundRequest,
|
||||
InputModerationSettings,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -287,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -319,6 +321,161 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaVideoGreenScreen(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaVideoGreenScreen",
|
||||
display_name="Bria Video Green Screen",
|
||||
category="partner/video/Bria",
|
||||
description="Replace a video's background with a solid chroma-key screen using Bria.",
|
||||
inputs=[
|
||||
IO.Video.Input("video"),
|
||||
IO.Combo.Input(
|
||||
"green_shade",
|
||||
options=["broadcast_green", "chroma_green", "blue_screen"],
|
||||
tooltip="Solid chroma-key shade applied behind the foreground: "
|
||||
"broadcast_green (#00B140), chroma_green (#00FF00), or blue_screen (#0000FF).",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
green_shade: str,
|
||||
seed: int,
|
||||
) -> IO.NodeOutput:
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/green_screen", method="POST"),
|
||||
data=BriaVideoGreenScreenRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
green_shade=green_shade,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="BriaVideoReplaceBackground",
|
||||
display_name="Bria Video Replace Background",
|
||||
category="partner/video/Bria",
|
||||
description="Replace a video's background with a supplied image or video using Bria. "
|
||||
"The output keeps the foreground's resolution and frame rate; a background with a "
|
||||
"different aspect ratio is stretched to fit, so match it for undistorted results.",
|
||||
inputs=[
|
||||
IO.Video.Input("video", tooltip="Foreground video whose background is replaced."),
|
||||
IO.Image.Input(
|
||||
"background_image",
|
||||
optional=True,
|
||||
tooltip="Background image to composite behind the foreground. "
|
||||
"Provide either a background image or a background video, not both.",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"background_video",
|
||||
optional=True,
|
||||
tooltip="Background video to composite behind the foreground. "
|
||||
"Provide either a background image or a background video, not both.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed controls whether the node should re-run; "
|
||||
"results are non-deterministic regardless of seed.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Video.Output()],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
background_image: Input.Image | None = None,
|
||||
background_video: Input.Video | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
if (background_image is None) == (background_video is None):
|
||||
raise ValueError("Provide either a background image or a background video, not both.")
|
||||
validate_video_duration(video, max_duration=60.0)
|
||||
if background_video is not None:
|
||||
validate_video_duration(background_video, max_duration=60.0)
|
||||
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||
else:
|
||||
# Bria's replace_background 500s on RGBA, so drop the alpha channel before upload.
|
||||
background_url = await upload_image_to_comfyapi(
|
||||
cls, background_image[:, :, :, :3], wait_label="Uploading background"
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||
data=BriaVideoReplaceBackgroundRequest(
|
||||
video=await upload_video_to_comfyapi(cls, video),
|
||||
background_url=background_url,
|
||||
output_container_and_codec="mp4_h264",
|
||||
seed=seed,
|
||||
),
|
||||
response_model=BriaStatusResponse,
|
||||
)
|
||||
response = await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"/proxy/bria/v2/status/{response.request_id}"),
|
||||
status_extractor=lambda r: r.status,
|
||||
response_model=BriaRemoveVideoBackgroundResponse,
|
||||
)
|
||||
return IO.NodeOutput(await download_url_to_video_output(response.result.video_url))
|
||||
|
||||
|
||||
def _video_to_images_and_mask(video: Input.Video) -> tuple[Input.Image, Input.Mask]:
|
||||
"""Decode a transparent webm (VP9 + alpha) into image frames and an alpha mask.
|
||||
|
||||
@ -376,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -416,6 +573,8 @@ class BriaExtension(ComfyExtension):
|
||||
BriaImageEditNode,
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
BriaVideoGreenScreen,
|
||||
BriaVideoReplaceBackground,
|
||||
BriaTransparentVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
@ -7,6 +7,7 @@ from io import BytesIO
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.bytedance import (
|
||||
RECOMMENDED_PRESETS,
|
||||
@ -131,6 +132,44 @@ def _prepare_seedance_image(image: Input.Image) -> Input.Image:
|
||||
return image
|
||||
|
||||
|
||||
# Supported output aspect ratios, used to pre-size FLF frames to matching pixel pair to avoid the 1080p stretch jump.
|
||||
SEEDANCE2_RATIO_WH = {
|
||||
"16:9": (16, 9),
|
||||
"4:3": (4, 3),
|
||||
"1:1": (1, 1),
|
||||
"3:4": (3, 4),
|
||||
"9:16": (9, 16),
|
||||
"21:9": (21, 9),
|
||||
}
|
||||
SEEDANCE2_RES_SHORT_SIDE = {"480p": 480, "720p": 720, "1080p": 1080}
|
||||
|
||||
|
||||
def _seedance2_target_dims(resolution: str, ratio: str, image: torch.Tensor) -> tuple[int, int]:
|
||||
"""Exact supported output (width, height) for (resolution, ratio).
|
||||
|
||||
The shorter side equals the resolution number (e.g. 1080p 16:9 -> 1920x1080). For ratio
|
||||
"adaptive" (or any unexpected value) the ratio is derived from the image's own aspect, snapped
|
||||
to the nearest supported ratio, so the output keeps the frame's orientation.
|
||||
"""
|
||||
short = SEEDANCE2_RES_SHORT_SIDE[resolution]
|
||||
if ratio not in SEEDANCE2_RATIO_WH:
|
||||
aspect = image.shape[-2] / image.shape[-3] # W / H; tensor is (B, H, W, C)
|
||||
ratio = min(SEEDANCE2_RATIO_WH, key=lambda k: abs(SEEDANCE2_RATIO_WH[k][0] / SEEDANCE2_RATIO_WH[k][1] - aspect))
|
||||
rw, rh = SEEDANCE2_RATIO_WH[ratio]
|
||||
if rw >= rh: # landscape or square: shorter side is the height
|
||||
out_w, out_h = round(short * rw / rh), short
|
||||
else: # portrait: shorter side is the width
|
||||
out_w, out_h = short, round(short * rh / rw)
|
||||
return out_w - out_w % 2, out_h - out_h % 2
|
||||
|
||||
|
||||
def _resize_to_exact(image: torch.Tensor, width: int, height: int) -> torch.Tensor:
|
||||
"""Center-crop to the target aspect and resize to exactly width x height (lanczos)."""
|
||||
samples = image.movedim(-1, 1) # (B, H, W, C) -> (B, C, H, W)
|
||||
resized = common_upscale(samples, width, height, "lanczos", "center")
|
||||
return resized.movedim(1, -1)
|
||||
|
||||
|
||||
async def _resolve_reference_assets(
|
||||
cls: type[IO.ComfyNode],
|
||||
asset_ids: list[str],
|
||||
@ -1790,10 +1829,28 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
if last_frame is not None and last_frame_asset_id:
|
||||
raise ValueError("Provide only one of last_frame or last_frame_asset_id, not both.")
|
||||
|
||||
if first_frame is not None:
|
||||
first_frame = _prepare_seedance_image(first_frame)
|
||||
if last_frame is not None:
|
||||
last_frame = _prepare_seedance_image(last_frame)
|
||||
request_ratio = model["ratio"]
|
||||
if first_frame_asset_id or last_frame_asset_id:
|
||||
if first_frame is not None:
|
||||
first_frame = _prepare_seedance_image(first_frame)
|
||||
if last_frame is not None:
|
||||
last_frame = _prepare_seedance_image(last_frame)
|
||||
else:
|
||||
# The 1080p FLF stretch fix (pre-size frames to a supported pixel pair + submit ratio="adaptive")
|
||||
# only applies to local image inputs we can resize.
|
||||
request_ratio = "adaptive"
|
||||
target_dims: tuple[int, int] | None = None
|
||||
if first_frame is not None:
|
||||
validate_image_aspect_ratio(first_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
validate_image_dimensions(first_frame, min_width=300, min_height=300)
|
||||
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], first_frame)
|
||||
first_frame = _resize_to_exact(first_frame, *target_dims)
|
||||
if last_frame is not None:
|
||||
validate_image_aspect_ratio(last_frame, (2, 5), (5, 2), strict=False) # 0.4 to 2.5
|
||||
validate_image_dimensions(last_frame, min_width=300, min_height=300)
|
||||
if target_dims is None:
|
||||
target_dims = _seedance2_target_dims(model["resolution"], model["ratio"], last_frame)
|
||||
last_frame = _resize_to_exact(last_frame, *target_dims)
|
||||
|
||||
asset_ids_to_resolve = [a for a in (first_frame_asset_id, last_frame_asset_id) if a]
|
||||
image_assets: dict[str, str] = {}
|
||||
@ -1844,7 +1901,7 @@ class ByteDance2FirstLastFrameNode(IO.ComfyNode):
|
||||
content=content,
|
||||
generate_audio=model["generate_audio"],
|
||||
resolution=model["resolution"],
|
||||
ratio=model["ratio"],
|
||||
ratio=request_ratio,
|
||||
duration=model["duration"],
|
||||
seed=seed,
|
||||
watermark=watermark,
|
||||
|
||||
@ -8,7 +8,7 @@ import os
|
||||
from enum import Enum
|
||||
from fnmatch import fnmatch
|
||||
from io import BytesIO
|
||||
from typing import Literal
|
||||
from typing import Any, Literal
|
||||
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
@ -19,6 +19,7 @@ from comfy_api_nodes.apis.gemini import (
|
||||
GeminiContent,
|
||||
GeminiFileData,
|
||||
GeminiGenerateContentRequest,
|
||||
GeminiGenerationConfig,
|
||||
GeminiGenerateContentResponse,
|
||||
GeminiImageConfig,
|
||||
GeminiImageGenerateContentRequest,
|
||||
@ -40,13 +41,18 @@ from comfy_api_nodes.util import (
|
||||
get_number_of_images,
|
||||
sync_op,
|
||||
tensor_to_base64_string,
|
||||
upload_audio_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
validate_string,
|
||||
video_to_base64_string,
|
||||
)
|
||||
|
||||
GEMINI_BASE_ENDPOINT = "/proxy/vertexai/gemini"
|
||||
GEMINI_MAX_INPUT_FILE_SIZE = 20 * 1024 * 1024 # 20 MB
|
||||
GEMINI_URL_INPUT_BUDGET = 10
|
||||
GEMINI_MAX_INLINE_BYTES = 18 * 1024 * 1024
|
||||
GEMINI_IMAGE_SYS_PROMPT = (
|
||||
"You are an expert image-generation engine. You must ALWAYS produce an image.\n"
|
||||
"Interpret all user input—regardless of "
|
||||
@ -285,6 +291,140 @@ def calculate_tokens_price(response: GeminiGenerateContentResponse) -> float | N
|
||||
return final_price / 1_000_000.0
|
||||
|
||||
|
||||
def create_video_parts(video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert a single video input to Gemini API compatible parts (inline MP4/H.264)."""
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def create_audio_parts(audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""Convert an audio input to Gemini API compatible parts (one inline MP3 part per batch item)."""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
|
||||
def _flatten_images(images: list[Input.Image]) -> list[torch.Tensor]:
|
||||
"""Expand any batched image tensors into individual (H, W, C) frames, preserving order."""
|
||||
frames: list[torch.Tensor] = []
|
||||
for img in images:
|
||||
if len(img.shape) == 4:
|
||||
frames.extend(img[i] for i in range(img.shape[0]))
|
||||
else:
|
||||
frames.append(img)
|
||||
return frames
|
||||
|
||||
|
||||
def _flatten_audio(audios: list[Input.Audio]) -> list[Input.Audio]:
|
||||
"""Expand any batched audio inputs into individual single-clip audio inputs, preserving order."""
|
||||
clips: list[Input.Audio] = []
|
||||
for audio in audios:
|
||||
waveform = audio["waveform"]
|
||||
for i in range(waveform.shape[0]):
|
||||
clips.append(Input.Audio(waveform=waveform[i].unsqueeze(0), sample_rate=audio["sample_rate"]))
|
||||
return clips
|
||||
|
||||
|
||||
async def _media_url_part(cls: type[IO.ComfyNode], kind: str, payload: Any) -> GeminiPart:
|
||||
"""Upload a single media unit to ComfyAPI storage and return a fileData (URL) part."""
|
||||
if kind == "image":
|
||||
url = await upload_image_to_comfyapi(cls, payload, mime_type="image/png", wait_label="Uploading image")
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.image_png, fileUri=url))
|
||||
if kind == "audio":
|
||||
url = await upload_audio_to_comfyapi(
|
||||
cls, payload, container_format="mp3", codec_name="libmp3lame", mime_type="audio/mp3"
|
||||
)
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.audio_mp3, fileUri=url))
|
||||
url = await upload_video_to_comfyapi(cls, payload, wait_label="Uploading video")
|
||||
return GeminiPart(fileData=GeminiFileData(mimeType=GeminiMimeType.video_mp4, fileUri=url))
|
||||
|
||||
|
||||
def _media_inline_part(kind: str, payload: Any) -> tuple[GeminiPart, int]:
|
||||
"""Encode a single media unit as an inline base64 part; returns (part, base64_length)."""
|
||||
if kind == "image":
|
||||
data = tensor_to_base64_string(payload, mime_type="image/webp")
|
||||
mime = GeminiMimeType.image_webp
|
||||
elif kind == "audio":
|
||||
data = audio_to_base64_string(payload, container_format="mp3", codec_name="libmp3lame")
|
||||
mime = GeminiMimeType.audio_mp3
|
||||
else:
|
||||
data = video_to_base64_string(
|
||||
payload, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
mime = GeminiMimeType.video_mp4
|
||||
return GeminiPart(inlineData=GeminiInlineData(mimeType=mime, data=data)), len(data)
|
||||
|
||||
|
||||
async def build_gemini_media_parts(
|
||||
cls: type[IO.ComfyNode],
|
||||
images: list[Input.Image],
|
||||
audios: list[Input.Audio],
|
||||
videos: list[Input.Video],
|
||||
*,
|
||||
url_budget: int = GEMINI_URL_INPUT_BUDGET,
|
||||
max_inline_bytes: int = GEMINI_MAX_INLINE_BYTES,
|
||||
) -> list[GeminiPart]:
|
||||
"""Build Gemini parts for multimodal inputs (images, audio, video).
|
||||
|
||||
fileData URLs are preferred for every media type: the upload is fetched directly by the
|
||||
model, keeping the request body tiny regardless of media size. The URL budget is shared
|
||||
across all media and assigned largest-first (video, then audio, then images), so that if it
|
||||
is ever exhausted the inline-base64 overflow is limited to the smallest items. Total inline
|
||||
payload is capped by `max_inline_bytes`.
|
||||
"""
|
||||
units: list[tuple[str, Any]] = (
|
||||
[("video", v) for v in videos]
|
||||
+ [("audio", a) for a in _flatten_audio(audios)]
|
||||
+ [("image", f) for f in _flatten_images(images)]
|
||||
)
|
||||
|
||||
parts: list[GeminiPart] = []
|
||||
url_used = 0
|
||||
inline_bytes = 0
|
||||
for kind, payload in units:
|
||||
if url_used < url_budget:
|
||||
parts.append(await _media_url_part(cls, kind, payload))
|
||||
url_used += 1
|
||||
continue
|
||||
part, nbytes = _media_inline_part(kind, payload)
|
||||
inline_bytes += nbytes
|
||||
if inline_bytes > max_inline_bytes:
|
||||
raise ValueError(
|
||||
f"Too much media to send inline (over {max_inline_bytes // (1024 * 1024)}MB after the first "
|
||||
f"{url_budget} inputs are uploaded as URLs). Reduce the number or size of attached media."
|
||||
)
|
||||
parts.append(part)
|
||||
return parts
|
||||
|
||||
|
||||
class GeminiNode(IO.ComfyNode):
|
||||
"""
|
||||
Node to generate text responses from a Gemini model.
|
||||
@ -407,58 +547,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
)
|
||||
""",
|
||||
),
|
||||
is_deprecated=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def create_video_parts(cls, video_input: Input.Video) -> list[GeminiPart]:
|
||||
"""Convert video input to Gemini API compatible parts."""
|
||||
|
||||
base_64_string = video_to_base64_string(
|
||||
video_input, container_format=Types.VideoContainer.MP4, codec=Types.VideoCodec.H264
|
||||
)
|
||||
return [
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.video_mp4,
|
||||
data=base_64_string,
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def create_audio_parts(cls, audio_input: Input.Audio) -> list[GeminiPart]:
|
||||
"""
|
||||
Convert audio input to Gemini API compatible parts.
|
||||
|
||||
Args:
|
||||
audio_input: Audio input from ComfyUI, containing waveform tensor and sample rate.
|
||||
|
||||
Returns:
|
||||
List of GeminiPart objects containing the encoded audio.
|
||||
"""
|
||||
audio_parts: list[GeminiPart] = []
|
||||
for batch_index in range(audio_input["waveform"].shape[0]):
|
||||
# Recreate an IO.AUDIO object for the given batch dimension index
|
||||
audio_at_index = Input.Audio(
|
||||
waveform=audio_input["waveform"][batch_index].unsqueeze(0),
|
||||
sample_rate=audio_input["sample_rate"],
|
||||
)
|
||||
# Convert to MP3 format for compatibility with Gemini API
|
||||
audio_bytes = audio_to_base64_string(
|
||||
audio_at_index,
|
||||
container_format="mp3",
|
||||
codec_name="libmp3lame",
|
||||
)
|
||||
audio_parts.append(
|
||||
GeminiPart(
|
||||
inlineData=GeminiInlineData(
|
||||
mimeType=GeminiMimeType.audio_mp3,
|
||||
data=audio_bytes,
|
||||
)
|
||||
)
|
||||
)
|
||||
return audio_parts
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
@ -482,9 +573,9 @@ class GeminiNode(IO.ComfyNode):
|
||||
if images is not None:
|
||||
parts.extend(await create_image_parts(cls, images))
|
||||
if audio is not None:
|
||||
parts.extend(cls.create_audio_parts(audio))
|
||||
parts.extend(create_audio_parts(audio))
|
||||
if video is not None:
|
||||
parts.extend(cls.create_video_parts(video))
|
||||
parts.extend(create_video_parts(video))
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
@ -512,6 +603,210 @@ class GeminiNode(IO.ComfyNode):
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
GEMINI_V2_MODELS: dict[str, str] = {
|
||||
"Gemini 3.1 Pro": "gemini-3.1-pro-preview",
|
||||
"Gemini 3.1 Flash-Lite": "gemini-3.1-flash-lite-preview",
|
||||
}
|
||||
|
||||
|
||||
def _gemini_text_model_inputs(thinking_default: str) -> list[Input]:
|
||||
"""Per-model inputs revealed by the model DynamicCombo (shared media + sampling controls)."""
|
||||
return [
|
||||
IO.Autogrow.Input(
|
||||
"images",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Image.Input("image"),
|
||||
names=[f"image_{i}" for i in range(1, 17)],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional image(s) to use as context for the model. Up to 16 images.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"audio",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Audio.Input("audio"),
|
||||
names=["audio_1"],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional audio clip to use as context for the model.",
|
||||
),
|
||||
IO.Autogrow.Input(
|
||||
"video",
|
||||
template=IO.Autogrow.TemplateNames(
|
||||
IO.Video.Input("video"),
|
||||
names=["video_1"],
|
||||
min=0,
|
||||
),
|
||||
tooltip="Optional video clip to use as context for the model.",
|
||||
),
|
||||
IO.Custom("GEMINI_INPUT_FILES").Input(
|
||||
"files",
|
||||
optional=True,
|
||||
tooltip="Optional file(s) to use as context for the model. "
|
||||
"Accepts inputs from the Gemini Input Files node.",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"thinking_level",
|
||||
options=["LOW", "HIGH"],
|
||||
default=thinking_default,
|
||||
tooltip="How hard the model reasons internally before answering. "
|
||||
"HIGH improves quality on difficult tasks but costs more (thinking) tokens and is slower.",
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
tooltip="Controls randomness. Lower is more focused/deterministic, higher is more creative.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
tooltip="Nucleus sampling: sample from the smallest token set whose cumulative probability reaches top_p.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Int.Input(
|
||||
"max_output_tokens",
|
||||
default=32768,
|
||||
min=16,
|
||||
max=65536,
|
||||
tooltip="Maximum tokens to generate, including the model's internal thinking. "
|
||||
"With thinking_level HIGH, a low value can leave no room for the answer; raise this if "
|
||||
"responses come back empty or truncated. The model stops early when finished, so a higher "
|
||||
"cap costs nothing extra for short replies.",
|
||||
advanced=True,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
class GeminiNodeV2(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="GeminiNodeV2",
|
||||
display_name="Google Gemini",
|
||||
category="partner/text/Gemini",
|
||||
essentials_category="Text Generation",
|
||||
description="Generate text responses with Google's Gemini models. Provide a text prompt and, "
|
||||
"optionally, one or more images, audio clips, videos, or files as multimodal context.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Text input to the model. Include detailed instructions, questions, or context.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("Gemini 3.1 Pro", _gemini_text_model_inputs("HIGH")),
|
||||
IO.DynamicCombo.Option("Gemini 3.1 Flash-Lite", _gemini_text_model_inputs("LOW")),
|
||||
],
|
||||
tooltip="The Gemini model used to generate the response.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=42,
|
||||
min=0,
|
||||
max=2147483647,
|
||||
control_after_generate=True,
|
||||
tooltip="Seed for sampling. Set to 0 for a random seed. Deterministic output isn't guaranteed.",
|
||||
),
|
||||
IO.String.Input(
|
||||
"system_prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
optional=True,
|
||||
advanced=True,
|
||||
tooltip="Foundational instructions that dictate the model's behavior.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
depends_on=IO.PriceBadgeDepends(widgets=["model"]),
|
||||
expr="""
|
||||
(
|
||||
$m := widgets.model;
|
||||
$contains($m, "lite") ? {
|
||||
"type": "list_usd",
|
||||
"usd": [0.00025, 0.0015],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
} : {
|
||||
"type": "list_usd",
|
||||
"usd": [0.002, 0.012],
|
||||
"format": { "approximate": true, "separator": "-", "suffix": " per 1K tokens" }
|
||||
}
|
||||
)
|
||||
""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
model: dict,
|
||||
seed: int,
|
||||
system_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_id = GEMINI_V2_MODELS[model["model"]]
|
||||
|
||||
parts: list[GeminiPart] = [GeminiPart(text=prompt)]
|
||||
images = [t for t in (model.get("images") or {}).values() if t is not None]
|
||||
audios = [a for a in (model.get("audio") or {}).values() if a is not None]
|
||||
videos = [v for v in (model.get("video") or {}).values() if v is not None]
|
||||
if images or audios or videos:
|
||||
parts.extend(await build_gemini_media_parts(cls, images, audios, videos))
|
||||
files = model.get("files")
|
||||
if files is not None:
|
||||
parts.extend(files)
|
||||
|
||||
gemini_system_prompt = None
|
||||
if system_prompt:
|
||||
gemini_system_prompt = GeminiSystemInstructionContent(parts=[GeminiTextPart(text=system_prompt)], role=None)
|
||||
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=f"{GEMINI_BASE_ENDPOINT}/{model_id}", method="POST"),
|
||||
data=GeminiGenerateContentRequest(
|
||||
contents=[
|
||||
GeminiContent(
|
||||
role=GeminiRole.user,
|
||||
parts=parts,
|
||||
)
|
||||
],
|
||||
generationConfig=GeminiGenerationConfig(
|
||||
temperature=model["temperature"],
|
||||
topP=model["top_p"],
|
||||
maxOutputTokens=model["max_output_tokens"],
|
||||
seed=seed if seed > 0 else None,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
response_model=GeminiGenerateContentResponse,
|
||||
price_extractor=calculate_tokens_price,
|
||||
)
|
||||
|
||||
output_text = get_text_from_response(response)
|
||||
return IO.NodeOutput(output_text or "Empty response from Gemini model...")
|
||||
|
||||
|
||||
class GeminiInputFiles(IO.ComfyNode):
|
||||
"""
|
||||
Loads and formats input files for use with the Gemini API.
|
||||
@ -1129,6 +1424,26 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
tooltip="Foundational instructions that dictate an AI's behavior.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"temperature",
|
||||
default=1.0,
|
||||
min=0.0,
|
||||
max=2.0,
|
||||
step=0.01,
|
||||
optional=True,
|
||||
tooltip="Controls randomness in generation. Lower is more focused/deterministic.",
|
||||
advanced=True,
|
||||
),
|
||||
IO.Float.Input(
|
||||
"top_p",
|
||||
default=0.95,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
optional=True,
|
||||
tooltip="Nucleus sampling threshold. Lower is more focused, higher more diverse.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Image.Output(),
|
||||
@ -1165,6 +1480,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
seed: int,
|
||||
response_modalities: str,
|
||||
system_prompt: str = "",
|
||||
temperature: float = 1.0,
|
||||
top_p: float = 0.95,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, strip_whitespace=True, min_length=1)
|
||||
model_choice = model["model"]
|
||||
@ -1204,6 +1521,8 @@ class GeminiNanoBanana2V2(IO.ComfyNode):
|
||||
responseModalities=(["IMAGE"] if response_modalities == "IMAGE" else ["TEXT", "IMAGE"]),
|
||||
imageConfig=image_config,
|
||||
thinkingConfig=GeminiThinkingConfig(thinkingLevel=model["thinking_level"]),
|
||||
temperature=temperature,
|
||||
topP=top_p,
|
||||
),
|
||||
systemInstruction=gemini_system_prompt,
|
||||
),
|
||||
@ -1222,6 +1541,7 @@ class GeminiExtension(ComfyExtension):
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
return [
|
||||
GeminiNode,
|
||||
GeminiNodeV2,
|
||||
GeminiImage,
|
||||
GeminiImage2,
|
||||
GeminiNanoBanana2,
|
||||
|
||||
@ -436,7 +436,7 @@ async def execute_text2video(
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
mode=KlingVideoGenMode(model_mode),
|
||||
model_name=KlingVideoGenModelName(model_name),
|
||||
model_name=model_name,
|
||||
cfg_scale=cfg_scale,
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
|
||||
@ -42,9 +42,11 @@ async def _upload_image_to_krea_assets(cls: type[IO.ComfyNode], image: Input.Ima
|
||||
|
||||
|
||||
_MODEL_MEDIUM = "Krea 2 Medium"
|
||||
_MODEL_MEDIUM_TURBO = "Krea 2 Medium Turbo"
|
||||
_MODEL_LARGE = "Krea 2 Large"
|
||||
_MODEL_ENDPOINTS: dict[str, str] = {
|
||||
_MODEL_MEDIUM: "/proxy/krea/generate/image/krea/krea-2/medium",
|
||||
_MODEL_MEDIUM_TURBO: "/proxy/krea/generate/image/krea/krea-2/medium-turbo",
|
||||
_MODEL_LARGE: "/proxy/krea/generate/image/krea/krea-2/large",
|
||||
}
|
||||
|
||||
@ -57,7 +59,7 @@ _UUID_RE = re.compile(r"^[0-9a-fA-F]{8}-[0-9a-fA-F]{4}-[0-9a-fA-F]{4}-[0-9a-fA-F
|
||||
|
||||
|
||||
def _krea_model_inputs() -> list:
|
||||
"""Nested inputs shared by both Krea 2 Medium and Large under the DynamicCombo."""
|
||||
"""Nested inputs shared by Krea 2 Medium, Medium Turbo and Large under the DynamicCombo."""
|
||||
return [
|
||||
IO.Combo.Input(
|
||||
"aspect_ratio",
|
||||
@ -123,6 +125,7 @@ class Krea2ImageNode(IO.ComfyNode):
|
||||
"model",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_MEDIUM_TURBO, _krea_model_inputs()),
|
||||
IO.DynamicCombo.Option(_MODEL_LARGE, _krea_model_inputs()),
|
||||
],
|
||||
tooltip="Krea 2 Medium is best for expressive illustrations; "
|
||||
@ -151,14 +154,15 @@ class Krea2ImageNode(IO.ComfyNode):
|
||||
),
|
||||
expr="""
|
||||
(
|
||||
$isLarge := widgets.model = "krea 2 large";
|
||||
$rates := {
|
||||
"krea 2 medium turbo": {"text": 0.015, "style": 0.0175, "moodboard": 0.02},
|
||||
"krea 2 medium": {"text": 0.03, "style": 0.035, "moodboard": 0.04},
|
||||
"krea 2 large": {"text": 0.06, "style": 0.065, "moodboard": 0.07}
|
||||
};
|
||||
$r := $lookup($rates, widgets.model);
|
||||
$hasMoodboard := $length($lookup(widgets, "model.moodboard_id")) > 0;
|
||||
$hasStyle := $lookup(inputs, "model.style_reference").connected;
|
||||
$usd := $hasMoodboard
|
||||
? ($isLarge ? 0.07 : 0.04)
|
||||
: ($hasStyle
|
||||
? ($isLarge ? 0.065 : 0.035)
|
||||
: ($isLarge ? 0.06 : 0.03));
|
||||
$usd := $hasMoodboard ? $r.moodboard : ($hasStyle ? $r.style : $r.text);
|
||||
{"type":"usd","usd": $usd}
|
||||
)
|
||||
""",
|
||||
|
||||
@ -9,6 +9,7 @@ from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.openai import (
|
||||
InputFileContent,
|
||||
@ -62,7 +63,8 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
A torch.Tensor of shape (N, H, W, C) with all returned images; images whose
|
||||
dimensions differ from the first image's are resized to match it.
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
@ -89,6 +91,14 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
# With size="auto" the API can return images whose dimensions differ by a few pixels within a single response
|
||||
# resize them to the first image's dimensions so they can be stacked into one batch.
|
||||
ref_h, ref_w = image_tensors[0].shape[:2]
|
||||
for i, t in enumerate(image_tensors):
|
||||
if t.shape[:2] != (ref_h, ref_w):
|
||||
samples = t.unsqueeze(0).movedim(-1, 1)
|
||||
samples = common_upscale(samples, ref_w, ref_h, "bilinear", "center")
|
||||
image_tensors[i] = samples.movedim(1, -1).squeeze(0)
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
|
||||
@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
|
||||
Model4,
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
RunwayAleph2IO,
|
||||
RunwayAleph2KeyframeChain,
|
||||
RunwayAleph2KeyframeItem,
|
||||
RunwayAleph2PromptImageChain,
|
||||
RunwayAleph2PromptImageItem,
|
||||
RunwayAleph2Request,
|
||||
RunwayAleph2Response,
|
||||
RunwayAleph2KeyframeSeconds,
|
||||
RunwayAleph2KeyframeAt,
|
||||
RunwayAleph2PromptImage,
|
||||
RunwayAleph2TimestampPosition,
|
||||
RunwayAleph2RelativePosition,
|
||||
RunwayAleph2ContentModeration,
|
||||
KEYFRAME_MODE_SECONDS,
|
||||
KEYFRAME_MODE_AT,
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||
PROMPT_IMAGE_MODE_POSITION,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||
|
||||
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||
|
||||
|
||||
class RunwayApiError(Exception):
|
||||
"""Base exception for Runway API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||
|
||||
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@ -127,7 +133,7 @@ async def generate_video(
|
||||
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
mime_type="image/png",
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2KeyframeNode",
|
||||
display_name="Runway Aleph2 Keyframe",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
|
||||
"steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
|
||||
"the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'keyframes' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to apply at the chosen moment of the input video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"timing",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the input video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the input video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the input video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Optional earlier keyframes to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
timing: dict,
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain()
|
||||
if timing["timing"] == _TIMING_ABSOLUTE:
|
||||
mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"])
|
||||
else:
|
||||
mode, value = KEYFRAME_MODE_AT, float(timing["fraction"])
|
||||
chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2PromptImageNode",
|
||||
display_name="Runway Aleph2 Prompt Image",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the output (result) video, to guide what "
|
||||
"the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
|
||||
"Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'prompt_images' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to place at the chosen moment of the output video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"position",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the output video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the output video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the output video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Optional earlier prompt images to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
position: dict,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain()
|
||||
if position["position"] == _TIMING_ABSOLUTE:
|
||||
mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"])
|
||||
else:
|
||||
mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"])
|
||||
chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2VideoToVideoNode",
|
||||
display_name="Runway Aleph2 Video to Video",
|
||||
category="partner/video/Runway",
|
||||
description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
|
||||
"your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
|
||||
"the original motion and timing; the output resolution matches the input video, which must be "
|
||||
"2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
|
||||
"the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes what should appear in the output (1-1000 characters).",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"public_figure_threshold",
|
||||
options=["auto", "low"],
|
||||
default="low",
|
||||
tooltip="Content moderation for recognizable public figures.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
public_figure_threshold: str = "low",
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=1000)
|
||||
validate_video_duration(
|
||||
video,
|
||||
min_duration=2.0,
|
||||
max_duration=30.0,
|
||||
)
|
||||
try:
|
||||
fps = float(video.get_frame_rate())
|
||||
except Exception:
|
||||
fps = None
|
||||
if fps is not None and fps > 30.0 + 0.01:
|
||||
raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")
|
||||
|
||||
if (keyframes and keyframes.items) and (prompt_images and prompt_images.items):
|
||||
raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")
|
||||
|
||||
video_duration: float | None = None
|
||||
try:
|
||||
video_duration = video.get_duration()
|
||||
except Exception:
|
||||
video_duration = None
|
||||
|
||||
def _check_seconds(value: float, label: str) -> None:
|
||||
if video_duration is not None and value > video_duration + 0.0001:
|
||||
raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")
|
||||
|
||||
video_url = await upload_video_to_comfyapi(cls, video)
|
||||
|
||||
keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
|
||||
if keyframes is not None:
|
||||
if len(keyframes.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 keyframes.")
|
||||
for item in keyframes.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
if item.mode == KEYFRAME_MODE_SECONDS:
|
||||
_check_seconds(item.value, "Keyframe timestamp")
|
||||
keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
|
||||
else:
|
||||
keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))
|
||||
|
||||
prompt_image_models: list[RunwayAleph2PromptImage] = []
|
||||
if prompt_images is not None:
|
||||
if len(prompt_images.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 prompt images.")
|
||||
for item in prompt_images.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
|
||||
_check_seconds(item.value, "Prompt image timestamp")
|
||||
position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
|
||||
else:
|
||||
position = RunwayAleph2RelativePosition(positionPercentage=item.value)
|
||||
prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayAleph2Response,
|
||||
data=RunwayAleph2Request(
|
||||
promptText=prompt,
|
||||
videoUri=video_url,
|
||||
seed=seed,
|
||||
contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
|
||||
keyframes=keyframe_models or None,
|
||||
promptImage=prompt_image_models or None,
|
||||
),
|
||||
)
|
||||
|
||||
final_response = await get_response(cls, initial_response.id)
|
||||
if not final_response.output:
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
RunwayAleph2VideoToVideoNode,
|
||||
RunwayAleph2KeyframeNode,
|
||||
RunwayAleph2PromptImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
from comfy_api_nodes.util._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
)
|
||||
@ -174,8 +174,7 @@ async def _stream_sonilo_music(
|
||||
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
headers.update(get_auth_header(cls))
|
||||
headers = get_comfy_api_headers(cls)
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
node_id = get_node_id(cls)
|
||||
|
||||
@ -9,6 +9,7 @@ from io import BytesIO
|
||||
from yarl import URL
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
@ -35,6 +36,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
|
||||
"""Source of the prompt that triggered this API node.
|
||||
|
||||
Defaults to "comfyui-api" when the submitting client didn't identify itself,
|
||||
i.e. a direct API call to this server.
|
||||
"""
|
||||
return node_cls.hidden.comfy_usage_source or "comfyui-api"
|
||||
|
||||
|
||||
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
"""Common headers (auth, deploy environment, usage source) for Comfy API requests.
|
||||
|
||||
Centralizes the shared header set so every Comfy API request sends a consistent
|
||||
set and new shared headers only need to be added in one place. Intended for
|
||||
relative/cloud URLs resolved against ``default_base_url()``; because the result
|
||||
includes auth, callers must not attach it to arbitrary absolute/presigned URLs.
|
||||
"""
|
||||
return {
|
||||
**get_auth_header(node_cls),
|
||||
"Comfy-Env": get_deploy_environment(),
|
||||
"Comfy-Usage-Source": get_usage_source(node_cls),
|
||||
}
|
||||
|
||||
|
||||
def default_base_url() -> str:
|
||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||
|
||||
|
||||
@ -19,12 +19,10 @@ from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from server import PromptServer
|
||||
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
@ -645,8 +643,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from folder_paths import get_output_directory
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
@ -64,7 +64,7 @@ async def download_url_to_bytesio(
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_auth_header(cls)
|
||||
headers = get_comfy_api_headers(cls)
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
|
||||
66
comfy_execution/asset_enrichment.py
Normal file
66
comfy_execution/asset_enrichment.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Enrich executed-node output entries with asset id."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
def enrich_output_with_assets(output_ui: dict) -> dict:
|
||||
"""Register file-type output entries as assets and inject their ``id``.
|
||||
|
||||
Runs at output-processing time, once per produced output, when
|
||||
--enable-assets is set. Returns a new dict; entries without a resolvable
|
||||
on-disk file path are left unchanged. Errors are caught per-entry so a
|
||||
failure never blocks execution or the other entries.
|
||||
"""
|
||||
from comfy.cli_args import args
|
||||
if not args.enable_assets:
|
||||
return output_ui
|
||||
|
||||
import folder_paths
|
||||
from app.assets.services.ingest import register_file_in_place, DependencyMissingError
|
||||
|
||||
enriched = {}
|
||||
for key, entries in output_ui.items():
|
||||
if not isinstance(entries, list):
|
||||
enriched[key] = entries
|
||||
continue
|
||||
new_entries = []
|
||||
for entry in entries:
|
||||
if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry:
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
try:
|
||||
base = folder_paths.get_directory_by_type(entry["type"])
|
||||
if base is None:
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
base_abs = os.path.abspath(base)
|
||||
abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"]))
|
||||
try:
|
||||
if os.path.commonpath([base_abs, abs_path]) != base_abs:
|
||||
raise ValueError("escapes base")
|
||||
except ValueError:
|
||||
logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename"))
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
if not os.path.isfile(abs_path):
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
|
||||
# Register unconditionally: the file was just produced, and
|
||||
# register_file_in_place re-hashes so an overwritten path can
|
||||
# never carry a stale id.
|
||||
result = register_file_in_place(
|
||||
abs_path=abs_path,
|
||||
name=entry["filename"],
|
||||
tags=[entry["type"]],
|
||||
)
|
||||
|
||||
entry = dict(entry)
|
||||
entry["id"] = result.ref.id
|
||||
except DependencyMissingError:
|
||||
logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename"))
|
||||
except Exception:
|
||||
logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True)
|
||||
new_entries.append(entry)
|
||||
enriched[key] = new_entries
|
||||
return enriched
|
||||
@ -3,6 +3,7 @@ Job utilities for the /api/jobs endpoint.
|
||||
Provides normalization and helper functions for job status tracking.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from comfy_api.internal import prune_dict
|
||||
@ -19,6 +20,25 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
def validate_job_id(value) -> str:
|
||||
"""Validate a client-supplied job (prompt) id.
|
||||
|
||||
Job ids must be UUIDs in the canonical lowercase hyphenated form. The id
|
||||
is stored and compared verbatim everywhere downstream — history keys,
|
||||
websocket events, and /interrupt matching — so accepting another spelling
|
||||
would silently rewrite the client's id and then miss every exact-match
|
||||
lookup. Rejecting loudly beats that.
|
||||
|
||||
Returns the id unchanged. Raises ValueError when the value is not a
|
||||
string in canonical UUID form.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"job id must be a string, got {type(value).__name__}")
|
||||
if str(uuid.UUID(value)) != value:
|
||||
raise ValueError("job id must be a UUID in canonical lowercase hyphenated form")
|
||||
return value
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
|
||||
@ -158,7 +158,7 @@ class SaveAudio(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudio",
|
||||
search_aliases=["export flac"],
|
||||
display_name="Save Audio (FLAC)",
|
||||
display_name="Save Audio (FLAC) (DEPRECATED)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
@ -166,6 +166,7 @@ class SaveAudio(IO.ComfyNode):
|
||||
IO.String.Input("filename_prefix", default="audio/ComfyUI"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_deprecated=True,
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
@ -186,7 +187,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioMP3",
|
||||
search_aliases=["export mp3"],
|
||||
display_name="Save Audio (MP3)",
|
||||
display_name="Save Audio (MP3) (DEPRECATED)",
|
||||
category="audio",
|
||||
essentials_category="Audio",
|
||||
inputs=[
|
||||
@ -195,6 +196,7 @@ class SaveAudioMP3(IO.ComfyNode):
|
||||
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_deprecated=True,
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
@ -217,7 +219,7 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioOpus",
|
||||
search_aliases=["export opus"],
|
||||
display_name="Save Audio (Opus)",
|
||||
display_name="Save Audio (Opus) (DEPRECATED)",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio"),
|
||||
@ -225,6 +227,7 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_deprecated=True,
|
||||
is_output_node=True,
|
||||
outputs=[IO.Audio.Output("audio")]
|
||||
)
|
||||
@ -241,6 +244,54 @@ class SaveAudioOpus(IO.ComfyNode):
|
||||
)
|
||||
|
||||
|
||||
class SaveAudioAdvanced(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="SaveAudioAdvanced",
|
||||
search_aliases=["save audio", "export audio", "output audio", "write audio", "flac", "mp3", "opus"],
|
||||
display_name="Save Audio (Advanced)",
|
||||
description="Saves the input audio to your ComfyUI output directory.",
|
||||
category="audio",
|
||||
inputs=[
|
||||
IO.Audio.Input("audio", tooltip="The audio to save."),
|
||||
IO.String.Input(
|
||||
"filename_prefix",
|
||||
default="audio/ComfyUI",
|
||||
tooltip=(
|
||||
"The prefix for the file to save. May include formatting tokens "
|
||||
"such as %date:yyyy-MM-dd%."
|
||||
),
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"format",
|
||||
options=[
|
||||
IO.DynamicCombo.Option("flac", []),
|
||||
IO.DynamicCombo.Option("mp3", [
|
||||
IO.Combo.Input("quality", options=["V0", "128k", "320k"], default="V0"),
|
||||
]),
|
||||
IO.DynamicCombo.Option("opus", [
|
||||
IO.Combo.Input("quality", options=["64k", "96k", "128k", "192k", "320k"], default="128k"),
|
||||
]),
|
||||
],
|
||||
tooltip="The file format in which to save the audio.",
|
||||
),
|
||||
],
|
||||
hidden=[IO.Hidden.prompt, IO.Hidden.extra_pnginfo],
|
||||
is_output_node=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, audio, filename_prefix: str, format: dict) -> IO.NodeOutput:
|
||||
file_format = format.get("format", None)
|
||||
quality = format.get("quality", None)
|
||||
if quality:
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format, quality=quality)
|
||||
else:
|
||||
ui=UI.AudioSaveHelper.get_save_audio_ui(audio, filename_prefix=filename_prefix, cls=cls, format=file_format)
|
||||
return IO.NodeOutput(ui=ui)
|
||||
|
||||
|
||||
class PreviewAudio(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
@ -823,6 +874,7 @@ class AudioExtension(ComfyExtension):
|
||||
SaveAudio,
|
||||
SaveAudioMP3,
|
||||
SaveAudioOpus,
|
||||
SaveAudioAdvanced,
|
||||
LoadAudio,
|
||||
PreviewAudio,
|
||||
ConditioningStableAudio,
|
||||
|
||||
115
comfy_extras/nodes_bernini.py
Normal file
115
comfy_extras/nodes_bernini.py
Normal file
@ -0,0 +1,115 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def _resize_long_edge(image, max_size, stride=16):
|
||||
"""Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`"""
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = min(max_size / max(h, w), 1.0)
|
||||
nh = max(stride, round(h * scale / stride) * stride)
|
||||
nw = max(stride, round(w * scale / stride) * stride)
|
||||
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
class BerniniConditioning(io.ComfyNode):
|
||||
"""Bernini in-context conditioning for a Wan2.2-A14B model.
|
||||
|
||||
Attaches the VAE-encoded source video / reference images to the conditioning
|
||||
source video first, then each reference image
|
||||
|
||||
The task is inferred from which inputs are connected:
|
||||
(nothing) -> t2v (text-to-video)
|
||||
source_video -> v2v (video-to-video)
|
||||
source_video + ref_images -> rv2v (reference-guided video editing)
|
||||
ref_images only -> r2v (reference-to-video)
|
||||
source_video + ref_video -> ads2v (insert image/video into video)
|
||||
|
||||
source_video is the edit base / canvas (resized to width x height).
|
||||
reference_video is moving content to composite in.
|
||||
Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="BerniniConditioning",
|
||||
display_name="Bernini Conditioning",
|
||||
category="conditioning/video_models",
|
||||
description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)."
|
||||
"Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=8192, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("source_video", optional=True, tooltip=(
|
||||
"Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")),
|
||||
io.Image.Input("reference_video", optional=True, tooltip=(
|
||||
"Video to insert into the source video (ads2v).")),
|
||||
io.Autogrow.Input("reference_images", optional=True,
|
||||
template=io.Autogrow.TemplatePrefix(
|
||||
input=io.Image.Input("reference_image", tooltip=(
|
||||
"Reference image injected as an in-context token (r2v, rv2v).")),
|
||||
prefix="reference_image_", min=0, max=8)),
|
||||
io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=(
|
||||
"Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size,
|
||||
source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
|
||||
# source_video (1), reference_video (2), reference_images (3, 4, ...).
|
||||
context = []
|
||||
if source_video is not None:
|
||||
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
context.append(vae.encode(vid[:, :, :, :3]))
|
||||
|
||||
if reference_video is not None:
|
||||
ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect
|
||||
context.append(vae.encode(ref_vid[:, :, :, :3]))
|
||||
|
||||
# reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a
|
||||
# separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame).
|
||||
if reference_images:
|
||||
for name in sorted(reference_images):
|
||||
imgs = reference_images[name]
|
||||
if imgs is None:
|
||||
continue
|
||||
for i in range(imgs.shape[0]):
|
||||
img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref
|
||||
context.append(vae.encode(img[:, :, :, :3]))
|
||||
|
||||
if context:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent})
|
||||
|
||||
|
||||
class BerniniExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
BerniniConditioning,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BerniniExtension:
|
||||
return BerniniExtension()
|
||||
@ -36,15 +36,15 @@ class RemoveBackground(IO.ComfyNode):
|
||||
category="image/background removal",
|
||||
description="Generates a foreground mask to remove the background from an image using a background removal model.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from"),
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"),
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mask.Output("mask", tooltip="Generated foreground mask")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, image, bg_removal_model):
|
||||
def execute(cls, bg_removal_model, image):
|
||||
mask = bg_removal_model.encode_image(image)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
|
||||
@ -7,29 +7,29 @@ class ColorToRGBInt(io.ComfyNode):
|
||||
def define_schema(cls) -> io.Schema:
|
||||
return io.Schema(
|
||||
node_id="ColorToRGBInt",
|
||||
display_name="Color to RGB Int",
|
||||
display_name="Color Picker",
|
||||
category="utilities",
|
||||
description="Convert a color to a RGB integer value.",
|
||||
description="Return a color RGB integer value and hexadecimal representation.",
|
||||
inputs=[
|
||||
io.Color.Input("color"),
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(display_name="rgb_int"),
|
||||
io.Color.Output(display_name="hex")
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
color: str,
|
||||
) -> io.NodeOutput:
|
||||
def execute(cls, color: str) -> io.NodeOutput:
|
||||
# expect format #RRGGBB
|
||||
if len(color) != 7 or color[0] != "#":
|
||||
raise ValueError("Color must be in format #RRGGBB")
|
||||
r = int(color[1:3], 16)
|
||||
g = int(color[3:5], 16)
|
||||
b = int(color[5:7], 16)
|
||||
return io.NodeOutput(r * 256 * 256 + g * 256 + b)
|
||||
|
||||
rgb_int = r * 256 * 256 + g * 256 + b
|
||||
return io.NodeOutput(rgb_int, color)
|
||||
|
||||
|
||||
class ColorExtension(ComfyExtension):
|
||||
|
||||
@ -933,9 +933,10 @@ class Guider_DualModel(comfy.samplers.CFGGuider):
|
||||
|
||||
def predict_noise(self, x, timestep, model_options={}, seed=None):
|
||||
positive = self.conds.get("positive", None)
|
||||
if self.uncond_inner is None: # cfg == 1 or no negative -> single model, cond only
|
||||
return comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0]
|
||||
cond = comfy.samplers.calc_cond_batch(self.inner_model, [positive], x, timestep, model_options)[0]
|
||||
# uncond model not loaded (base cfg==1/no negative), or cfg driven to 1.0 this step -> single model, cond only
|
||||
if self.uncond_inner is None or (math.isclose(self.cfg, 1.0) and not model_options.get("disable_cfg1_optimization", False)):
|
||||
return cond
|
||||
|
||||
uncond_model_options = model_options
|
||||
if "multigpu_clones" in model_options: # TODO: support multigpu instead of just running uncond on a single GPU
|
||||
@ -1140,7 +1141,7 @@ class CFGOverride(io.ComfyNode):
|
||||
return io.Schema(
|
||||
node_id="CFGOverride",
|
||||
display_name="CFG Override",
|
||||
description="Override cfg to a fixed value over a [start, end] percent slice of the steps. "
|
||||
description="Override cfg to a fixed value over a [start, end] percent (sigma) range. "
|
||||
"With multiple overrides, the one nearest the sampler wins on overlap.",
|
||||
category="sampling/custom_sampling",
|
||||
inputs=[
|
||||
|
||||
681
comfy_extras/nodes_depth_anything_3.py
Normal file
681
comfy_extras/nodes_depth_anything_3.py
Normal file
@ -0,0 +1,681 @@
|
||||
"""ComfyUI nodes for Depth Anything 3.
|
||||
Model capability matrix:
|
||||
|
||||
Variant head_type has_sky has_conf cam_dec
|
||||
DA3-Small dualdpt False True yes
|
||||
DA3-Base dualdpt False True yes
|
||||
DA3-Mono-Large dpt True False no
|
||||
DA3-Metric-Large dpt True False no (raw output is metres)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management as mm
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from comfy.ldm.colormap import turbo as _turbo
|
||||
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
|
||||
DA3ModelType = io.Custom("DA3_MODEL")
|
||||
DA3Geometry = io.Custom("DA3_GEOMETRY")
|
||||
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
|
||||
|
||||
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||
#
|
||||
# Per-frame tensors - B = batch size in mono mode; B = S (number of views) in multi-view mode.
|
||||
# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention)
|
||||
# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
|
||||
# "mode": str -- "mono" or "multiview" (always present)
|
||||
# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only)
|
||||
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
|
||||
#
|
||||
# Multi-view only - S = number of views; the leading 1 is the scene dimension from the model.
|
||||
# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices
|
||||
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
||||
#
|
||||
# DA3_POINT_CLOUD is a dict:
|
||||
# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
|
||||
# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None
|
||||
# "confidence": torch.Tensor (N,) -- raw confidence per point, or None
|
||||
|
||||
|
||||
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
|
||||
"""Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space."""
|
||||
H, W = depth.shape
|
||||
u = torch.arange(W, dtype=torch.float32, device=depth.device)
|
||||
v = torch.arange(H, dtype=torch.float32, device=depth.device)
|
||||
u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W)
|
||||
pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3)
|
||||
rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix)
|
||||
return rays * depth.unsqueeze(-1) # (H, W, 3)
|
||||
|
||||
|
||||
def _da3_default_K(H: int, W: int) -> torch.Tensor:
|
||||
"""Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry)."""
|
||||
fx = fy = float(W) * 0.7
|
||||
return torch.tensor([[fx, 0.0, (W - 1) / 2.0],
|
||||
[0.0, fy, (H - 1) / 2.0],
|
||||
[0.0, 0.0, 1.0]], dtype=torch.float32)
|
||||
|
||||
|
||||
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
|
||||
"""Return pixel-space K for batch element b, falling back to a default estimate."""
|
||||
if "intrinsics" in geometry:
|
||||
# shape (1, S, 3, 3) - leading scene dimension from the multiview head
|
||||
return geometry["intrinsics"][0, b].float()
|
||||
logging.getLogger("comfy").warning(
|
||||
"DA3_GEOMETRY has no intrinsics (mono-mode model). "
|
||||
"Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
|
||||
)
|
||||
return _da3_default_K(H, W)
|
||||
|
||||
|
||||
def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None:
|
||||
"""Return the world-to-camera extrinsic for batch element b, or None in mono mode.
|
||||
|
||||
The model outputs (1, S, 3, 4) [R|t] matrices; the fallback identity is (4, 4).
|
||||
_da3_apply_extrinsic handles both shapes via [:3, :3] / [:3, 3] slicing.
|
||||
"""
|
||||
if "extrinsics" not in geometry:
|
||||
return None
|
||||
return geometry["extrinsics"][0, b].float()
|
||||
|
||||
|
||||
def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
|
||||
"""Transform (H,W,3) OpenCV camera-space points to world space."""
|
||||
E = E.to(points_cam.device).float()
|
||||
if not torch.isfinite(E).all():
|
||||
logging.getLogger("comfy").warning(
|
||||
"DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). "
|
||||
"Falling back to camera-space coordinates."
|
||||
)
|
||||
return points_cam
|
||||
H, W, _ = points_cam.shape
|
||||
R = E[:3, :3] # (3, 3) rotation
|
||||
t = E[:3, 3] # (3,) translation
|
||||
R_inv = R.T # rotation inverse = transpose for orthogonal R
|
||||
t_inv = -(R_inv @ t) # (3,)
|
||||
pts = points_cam.reshape(-1, 3) # (N, 3)
|
||||
pts_world = pts @ R_inv.T + t_inv # (N, 3)
|
||||
return pts_world.reshape(H, W, 3)
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence to [0, 1] per image."""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
c = conf[i]
|
||||
c_min, c_max = c.min(), c.max()
|
||||
out.append((c - c_min) / (c_max - c_min) if c_max > c_min else torch.ones_like(c))
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
def _da3_build_mask(geometry: dict, b: int, H: int, W: int, confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
|
||||
"""Build (H,W) bool keep-mask from sky probability and confidence."""
|
||||
mask = torch.ones(H, W, dtype=torch.bool)
|
||||
if use_sky_mask and "sky" in geometry:
|
||||
mask = mask & (geometry["sky"][b] < 0.5)
|
||||
if "confidence" in geometry and confidence_threshold > 0.0:
|
||||
conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0]
|
||||
mask = mask & (conf_norm >= confidence_threshold)
|
||||
return mask
|
||||
|
||||
|
||||
class LoadDA3Model(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadDA3Model",
|
||||
display_name="Load Depth Anything 3",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"model_name",
|
||||
options=folder_paths.get_filename_list("geometry_estimation"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"weight_dtype",
|
||||
options=["default", "fp16", "bf16", "fp32"],
|
||||
default="default",
|
||||
),
|
||||
],
|
||||
outputs=[DA3ModelType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
|
||||
model_options = {}
|
||||
if weight_dtype == "fp16":
|
||||
model_options["dtype"] = torch.float16
|
||||
elif weight_dtype == "bf16":
|
||||
model_options["dtype"] = torch.bfloat16
|
||||
elif weight_dtype == "fp32":
|
||||
model_options["dtype"] = torch.float32
|
||||
|
||||
path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
|
||||
model = comfy.sd.load_diffusion_model(path, model_options=model_options)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on (B,H,W,3), returns depth/conf/sky at original resolution (or None)."""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
|
||||
B, H, W, _ = image.shape
|
||||
mm.load_model_gpu(model_patcher)
|
||||
diffusion = model_patcher.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
depths, confs, skies = [], [], []
|
||||
for i in range(B):
|
||||
single = image[i:i + 1].to(device)
|
||||
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
|
||||
x = x.to(dtype=dtype)
|
||||
with torch.no_grad():
|
||||
out = diffusion(x)
|
||||
|
||||
depth_lr = out["depth"]
|
||||
depth_full = torch.nn.functional.interpolate(
|
||||
depth_lr.unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
depths.append(depth_full)
|
||||
|
||||
if "depth_conf" in out:
|
||||
conf_full = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
confs.append(conf_full)
|
||||
if "sky" in out:
|
||||
sky_full = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
skies.append(sky_full)
|
||||
|
||||
depth = torch.cat(depths, dim=0)
|
||||
confidence = torch.cat(confs, dim=0) if confs else None
|
||||
sky = torch.cat(skies, dim=0) if skies else None
|
||||
return depth, confidence, sky
|
||||
|
||||
|
||||
class DA3Inference(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3Inference",
|
||||
search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"],
|
||||
display_name="Run Depth Anything 3",
|
||||
category="image/geometry estimation",
|
||||
description="Run Depth Anything 3 on an image. In multi-view mode each image is treated as a separate view of the same scene.",
|
||||
inputs=[
|
||||
DA3ModelType.Input("da3_model"),
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("resolution", default=504, min=140, max=2520, step=14,
|
||||
tooltip="Resolution the model runs at (longest side, multiple of 14).\n"
|
||||
"Lower = faster / less VRAM.\n"
|
||||
"Higher = more detail.\n"
|
||||
"Output is upsampled back to the original size."),
|
||||
io.Combo.Input("resize_method", options=["upper_bound_resize", "lower_bound_resize"], default="upper_bound_resize",
|
||||
tooltip="upper_bound_resize: scale so the longest side = resolution (caps memory, default).\n"
|
||||
"lower_bound_resize: scale so the shortest side = resolution (preserves more detail on tall/wide images, uses more memory)."),
|
||||
io.DynamicCombo.Input("mode", tooltip="mono: single view image (works with any model variant).\n"
|
||||
"multiview: all images processed together for geometric consistency + camera pose (for Small/Base models only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("mono", []),
|
||||
io.DynamicCombo.Option("multiview", [
|
||||
io.Combo.Input("ref_view_strategy", options=["saddle_balanced", "saddle_sim_range", "first", "middle"], default="saddle_balanced",
|
||||
tooltip="Which view acts as the geometric anchor.\n"
|
||||
"- saddle_balanced: the view most 'average' across all others (best general choice).\n"
|
||||
"- saddle_sim_range: the view most visually distinct from the others.\n"
|
||||
"- first / middle: fixed positional picks."),
|
||||
io.Combo.Input("pose_method", options=["cam_dec", "ray_pose"], default="cam_dec",
|
||||
tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n"
|
||||
"- cam_dec: learned from image features.\n"
|
||||
"- ray_pose: derived geometrically from the model's 3D ray output.\n"
|
||||
"Affects perspective correctness of the 3D output. Try both if results look distorted."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
outputs=[
|
||||
DA3Geometry.Output("da3_geometry", tooltip="Dictionary of non-normalized tensors.\n"
|
||||
"Always has the keys: depth, image, mode.\n"
|
||||
"Optional keys: sky (for Mono/Metric), confidence (for Small/Base), extrinsics + intrinsics (for multi-view)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_model, image, resolution, resize_method, mode) -> io.NodeOutput:
|
||||
mode_val = mode["mode"] # "mono" or "multiview"
|
||||
|
||||
if mode_val == "mono":
|
||||
return cls._execute_mono(da3_model, image, resolution, resize_method)
|
||||
|
||||
# Capability checks for multi-view mode.
|
||||
diffusion = da3_model.model.diffusion_model
|
||||
pose_method = mode["pose_method"]
|
||||
ref_view_strategy = mode["ref_view_strategy"]
|
||||
|
||||
has_cam_dec = diffusion.cam_dec is not None
|
||||
has_dualdpt = diffusion.head_type == "dualdpt"
|
||||
|
||||
if not has_cam_dec and not has_dualdpt:
|
||||
raise ValueError(
|
||||
"multi-view mode requires Small or Base model. The loaded model "
|
||||
f"(head_type='{diffusion.head_type}') does not support cross-view "
|
||||
"attention or camera pose estimation. Switch mode to 'mono', or "
|
||||
"load Small or Base model for mult-view."
|
||||
)
|
||||
|
||||
if pose_method == "cam_dec" and not has_cam_dec:
|
||||
raise ValueError(
|
||||
"pose_method='cam_dec' requires a camera decoder, but the loaded "
|
||||
f"model (head_type='{diffusion.head_type}') does not have one. "
|
||||
"Use pose_method='ray_pose' instead."
|
||||
)
|
||||
if pose_method == "ray_pose" and not has_dualdpt:
|
||||
raise ValueError(
|
||||
"pose_method='ray_pose' requires a DualDPT head, but the loaded "
|
||||
f"model has a '{diffusion.head_type}' head. "
|
||||
"Use pose_method='cam_dec' instead."
|
||||
)
|
||||
|
||||
return cls._execute_multiview(
|
||||
da3_model, image, resolution, resize_method,
|
||||
ref_view_strategy, pose_method,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, resolution, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, resolution, method=resize_method)
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "mono",
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
if confidence is not None:
|
||||
geometry["confidence"] = confidence.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
@classmethod
|
||||
def _execute_multiview(cls, model, image, resolution, resize_method, ref_view_strategy, pose_method) -> io.NodeOutput:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
S, H, W, _ = image.shape
|
||||
|
||||
mm.load_model_gpu(model)
|
||||
diffusion = model.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
# All views in a single forward pass: (1, S, 3, H', W').
|
||||
x = image.to(device)
|
||||
x = da3_preprocess.preprocess_image(x, process_res=resolution, method=resize_method)
|
||||
x = x.to(dtype=dtype).unsqueeze(0)
|
||||
|
||||
use_ray_pose = (pose_method == "ray_pose")
|
||||
with torch.no_grad():
|
||||
out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy)
|
||||
|
||||
depth = torch.nn.functional.interpolate(
|
||||
out["depth"].float().unsqueeze(1), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
sky = None
|
||||
if "sky" in out:
|
||||
sky = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
if "extrinsics" in out and "intrinsics" in out:
|
||||
extrinsics = out["extrinsics"].float().cpu()
|
||||
intrinsics = out["intrinsics"].float().cpu()
|
||||
else:
|
||||
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
|
||||
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "multiview",
|
||||
"extrinsics": extrinsics.contiguous(),
|
||||
"intrinsics": intrinsics.contiguous(),
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
if "depth_conf" in out:
|
||||
conf = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
geometry["confidence"] = conf.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
class DA3Render(io.ComfyNode):
|
||||
"""Render a visualization from a DA3_GEOMETRY packet."""
|
||||
|
||||
_DEPTH_RENDER_INPUTS = [
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
|
||||
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
|
||||
"- raw: no scaling,preserves metric units for Metric model."),
|
||||
io.Boolean.Input("apply_sky_clip", default=False,
|
||||
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before normalisation. "
|
||||
"Requires a sky key in the da3_geometry input (for Mono/Metric models only)."),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3Render",
|
||||
display_name="Render Depth Anything 3",
|
||||
category="image/geometry estimation",
|
||||
description="Render a depth map, confidence map, or sky mask from Depth Anything 3 geometry data.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.DynamicCombo.Input("output",
|
||||
tooltip="- depth: normalised greyscale depth image.\n"
|
||||
"- depth_colored: depth mapped through the Turbo colormap.\n"
|
||||
"- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n"
|
||||
"- confidence: normalised depth confidence (for Small/Base models only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("sky_mask", [
|
||||
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the sky mask."),
|
||||
]),
|
||||
io.DynamicCombo.Option("confidence", [
|
||||
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the confidence map."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, output) -> io.NodeOutput:
|
||||
output_val = output["output"]
|
||||
|
||||
if output_val in ("depth", "depth_colored"):
|
||||
normalization = output["normalization"]
|
||||
apply_sky_clip = output["apply_sky_clip"]
|
||||
if apply_sky_clip and "sky" not in da3_geometry:
|
||||
raise ValueError(
|
||||
"apply_sky_clip=True requires a sky tensor in the da3_geometry input, but none is present. "
|
||||
"Run with Mono/Metric models or set apply_sky_clip=False."
|
||||
)
|
||||
depth = da3_geometry["depth"]
|
||||
sky = da3_geometry.get("sky")
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale
|
||||
result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey
|
||||
|
||||
elif output_val == "sky_mask":
|
||||
if "sky" not in da3_geometry:
|
||||
raise ValueError("geometry has no sky output; run with Mono/Metric models.")
|
||||
sky = da3_geometry["sky"]
|
||||
if output["colored"]:
|
||||
result = _turbo(sky)
|
||||
else:
|
||||
result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous()
|
||||
|
||||
elif output_val == "confidence":
|
||||
if "confidence" not in da3_geometry:
|
||||
raise ValueError("da3_geometry has no confidence output; run with Small/Base models.")
|
||||
conf = _normalize_confidence(da3_geometry["confidence"])
|
||||
if output["colored"]:
|
||||
result = _turbo(conf)
|
||||
else:
|
||||
result = conf.unsqueeze(-1).expand(*conf.shape, 3).contiguous()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown output mode: {output_val}")
|
||||
|
||||
return io.NodeOutput(result.float())
|
||||
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
|
||||
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(N)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
if normalization != "raw":
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
class DA3GeometryToMesh(io.ComfyNode):
|
||||
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3GeometryToMesh",
|
||||
search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
|
||||
display_name="Convert DA3 Geometry to Mesh",
|
||||
category="image/geometry estimation",
|
||||
description="Convert a depth map into a triangulated 3D mesh.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert. Per-image vertex counts differ so batches cannot be stacked."),
|
||||
io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride. 1 = full resolution, 2 = half, etc."),
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop triangles whose 3x3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
|
||||
"Used when the geometry has a confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True, tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. Used when the geometry has a sky map (Mono/Metric models)."),
|
||||
io.Boolean.Input("texture", default=True, tooltip="Use the source image as a base color texture."),
|
||||
],
|
||||
outputs=[io.Mesh.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index] # (H, W)
|
||||
H, W = depth.shape
|
||||
|
||||
# NaN/inf depth would propagate silently through unproject and produce an
|
||||
# empty mesh; replace them with 0 here so those pixels are later excluded
|
||||
# by the isfinite check inside triangulate_grid_mesh.
|
||||
depth = depth.clone()
|
||||
n_bad = (~torch.isfinite(depth)).sum().item()
|
||||
if n_bad:
|
||||
logging.getLogger("comfy").warning(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels "
|
||||
f"({100*n_bad/(H*W):.1f}%) - zeroed before unproject."
|
||||
)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
logging.getLogger("comfy").debug(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] range "
|
||||
f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}"
|
||||
)
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
# Also exclude pixels where depth was invalid.
|
||||
mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index])
|
||||
points = points.clone()
|
||||
points[~mask] = float('inf')
|
||||
|
||||
verts, faces, uvs = triangulate_grid_mesh(
|
||||
points,
|
||||
decimation=decimation,
|
||||
discontinuity_threshold=discontinuity_threshold,
|
||||
depth=depth,
|
||||
)
|
||||
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"DA3GeometryToMesh produced an empty mesh. "
|
||||
"Try raising discontinuity_threshold, lowering confidence_threshold, "
|
||||
"or disabling use_sky_mask."
|
||||
)
|
||||
|
||||
# OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
|
||||
# Same transform as MoGePointMapToMesh perspective branch.
|
||||
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
||||
faces = faces[:, [0, 2, 1]].contiguous()
|
||||
|
||||
tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
|
||||
mesh = Types.MESH(
|
||||
vertices=verts.unsqueeze(0),
|
||||
faces=faces.unsqueeze(0),
|
||||
uvs=uvs.unsqueeze(0),
|
||||
texture=tex,
|
||||
)
|
||||
return io.NodeOutput(mesh)
|
||||
|
||||
|
||||
class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
"""Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3GeometryToPointCloud",
|
||||
search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
|
||||
display_name="Convert DA3 Geometry to Point Cloud",
|
||||
category="image/geometry estimation",
|
||||
description="Convert a depth map into a 3D point cloud.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). Used when the geometry has a confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5). Used when the geometry has a sky map (Mono/Metric models)."),
|
||||
io.Int.Input("downsample", default=1, min=1, max=16,
|
||||
tooltip="Take every Nth pixel (1 = full resolution). Higher values give fewer points and faster processing."),
|
||||
],
|
||||
# TODO: add a proper PointCloud output type
|
||||
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, confidence_threshold, use_sky_mask, downsample) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index].clone() # (H, W)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
H, W = depth.shape
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
|
||||
if downsample > 1:
|
||||
depth = depth[::downsample, ::downsample].contiguous()
|
||||
# Scale intrinsics to the downsampled grid.
|
||||
K = K.clone()
|
||||
K[0, :] /= downsample
|
||||
K[1, :] /= downsample
|
||||
|
||||
H_ds, W_ds = depth.shape
|
||||
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Rebuild mask at downsampled resolution.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
if downsample > 1:
|
||||
mask = mask[::downsample, ::downsample]
|
||||
|
||||
mask = mask & torch.isfinite(depth)
|
||||
|
||||
# OpenCV → glTF: flip Y and Z.
|
||||
points_gltf = points.clone()
|
||||
points_gltf[..., 1] *= -1.0
|
||||
points_gltf[..., 2] *= -1.0
|
||||
|
||||
pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)]
|
||||
|
||||
colors_flat = None
|
||||
if "image" in da3_geometry:
|
||||
img = da3_geometry["image"][batch_index] # (H, W, 3)
|
||||
if downsample > 1:
|
||||
img = img[::downsample, ::downsample]
|
||||
colors_flat = img.reshape(-1, 3)[mask.reshape(-1)]
|
||||
|
||||
conf_flat = None
|
||||
if "confidence" in da3_geometry:
|
||||
conf = da3_geometry["confidence"][batch_index] # (H, W)
|
||||
if downsample > 1:
|
||||
conf = conf[::downsample, ::downsample]
|
||||
conf_flat = conf.reshape(-1)[mask.reshape(-1)]
|
||||
|
||||
if pts_flat.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"DA3GeometryToPointCloud produced zero points after filtering. "
|
||||
"Try lowering confidence_threshold or disabling use_sky_mask."
|
||||
)
|
||||
|
||||
return io.NodeOutput({
|
||||
"points": pts_flat,
|
||||
"colors": colors_flat,
|
||||
"confidence": conf_flat,
|
||||
})
|
||||
|
||||
|
||||
class DA3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadDA3Model,
|
||||
DA3Inference,
|
||||
DA3Render,
|
||||
DA3GeometryToMesh,
|
||||
# DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> DA3Extension:
|
||||
return DA3Extension()
|
||||
@ -245,6 +245,11 @@ class KV_Attn_Input:
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
|
||||
# Fix batch size changing.
|
||||
kk = comfy.utils.repeat_to_batch_size(kk, k.shape[0])
|
||||
vv = comfy.utils.repeat_to_batch_size(vv, v.shape[0])
|
||||
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
|
||||
@ -488,7 +488,7 @@ class SplatToFile3D(IO.ComfyNode):
|
||||
"spz: Niantic gzip-compressed (~10x smaller), base color only "
|
||||
),
|
||||
],
|
||||
outputs=[IO.File3DAny.Output(display_name="model_3d")],
|
||||
outputs=[IO.File3DSplatAny.Output(display_name="model_3d")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -516,7 +516,7 @@ class File3DToSplat(IO.ComfyNode):
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
IO.File3DAny.Input("model_3d"),
|
||||
types=[IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ],
|
||||
types=[IO.File3DSplatAny, IO.File3DPLY, IO.File3DSPLAT, IO.File3DKSPLAT, IO.File3DSPZ],
|
||||
tooltip="A gaussian splat 3D file",
|
||||
),
|
||||
],
|
||||
|
||||
@ -51,6 +51,14 @@ class Load3D(IO.ComfyNode):
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def validate_inputs(cls, model_file, **kwargs) -> bool | str:
|
||||
if not model_file or model_file == "none":
|
||||
return True
|
||||
if not folder_paths.exists_annotated_filepath(model_file):
|
||||
return f"Invalid 3D model file: {model_file}"
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_file, image, **kwargs) -> IO.NodeOutput:
|
||||
image_path = folder_paths.get_annotated_filepath(image['image'])
|
||||
@ -136,7 +144,7 @@ class Preview3DAdvanced(IO.ComfyNode):
|
||||
is_output_node=True,
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_file",
|
||||
"model_3d",
|
||||
types=[
|
||||
IO.File3DGLB,
|
||||
IO.File3DGLTF,
|
||||
@ -148,34 +156,161 @@ class Preview3DAdvanced(IO.ComfyNode):
|
||||
],
|
||||
tooltip="3D model file from an upstream 3D node.",
|
||||
),
|
||||
IO.Load3D.Input("image"),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
||||
IO.Load3D.Input("viewport_state"),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DAny.Output(display_name="model_file"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.File3DAny.Output(display_name="model_3d"),
|
||||
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Int.Output(display_name="width"),
|
||||
IO.Int.Output(display_name="height"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_file: Types.File3D, image, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_file.format}"
|
||||
model_file.save_to(os.path.join(folder_paths.get_output_directory(), filename))
|
||||
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||
filename = f"preview3d_advanced_{uuid.uuid4().hex}.{model_3d.format}"
|
||||
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
|
||||
|
||||
camera_info_input = kwargs.get("camera_info", None)
|
||||
camera_info = camera_info_input if camera_info_input is not None else image['camera_info']
|
||||
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
|
||||
model_3d_info_input = kwargs.get("model_3d_info", None)
|
||||
model_3d_info = model_3d_info_input if model_3d_info_input is not None else image.get('model_3d_info', [])
|
||||
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
|
||||
return IO.NodeOutput(
|
||||
model_file,
|
||||
camera_info,
|
||||
model_3d,
|
||||
model_3d_info,
|
||||
camera_info,
|
||||
width,
|
||||
height,
|
||||
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
||||
)
|
||||
|
||||
|
||||
class PreviewGaussianSplat(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PreviewGaussianSplat",
|
||||
display_name="Preview Splat",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
search_aliases=[
|
||||
"view splat",
|
||||
"view gaussian",
|
||||
"view gaussian splat",
|
||||
"preview gaussian",
|
||||
"preview gaussian splat",
|
||||
"view 3dgs",
|
||||
"preview 3dgs",
|
||||
"preview ply",
|
||||
"preview spz",
|
||||
"preview splat",
|
||||
"preview ksplat",
|
||||
],
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[
|
||||
IO.File3DSplatAny,
|
||||
IO.File3DPLY,
|
||||
IO.File3DSPLAT,
|
||||
IO.File3DSPZ,
|
||||
IO.File3DKSPLAT,
|
||||
],
|
||||
tooltip="A gaussian splat 3D file.",
|
||||
),
|
||||
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
||||
IO.Load3D.Input("viewport_state"),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DSplatAny.Output(display_name="model_3d"),
|
||||
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Int.Output(display_name="width"),
|
||||
IO.Int.Output(display_name="height"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||
filename = f"preview_splat_{uuid.uuid4().hex}.{model_3d.format}"
|
||||
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
|
||||
|
||||
camera_info_input = kwargs.get("camera_info", None)
|
||||
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
|
||||
model_3d_info_input = kwargs.get("model_3d_info", None)
|
||||
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
|
||||
return IO.NodeOutput(
|
||||
model_3d,
|
||||
model_3d_info,
|
||||
camera_info,
|
||||
width,
|
||||
height,
|
||||
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
||||
)
|
||||
|
||||
|
||||
class PreviewPointCloud(IO.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="PreviewPointCloud",
|
||||
display_name="Preview Point Cloud",
|
||||
category="3d",
|
||||
is_experimental=True,
|
||||
is_output_node=True,
|
||||
search_aliases=[
|
||||
"view point cloud",
|
||||
"view pointcloud",
|
||||
"preview point cloud",
|
||||
"preview pointcloud",
|
||||
"preview ply",
|
||||
],
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[
|
||||
IO.File3DPointCloudAny,
|
||||
IO.File3DPLY,
|
||||
],
|
||||
tooltip="Point cloud file (.ply)",
|
||||
),
|
||||
IO.Load3DModelInfo.Input("model_3d_info", optional=True, advanced=True),
|
||||
IO.Load3D.Input("viewport_state"),
|
||||
IO.Load3DCamera.Input("camera_info", optional=True, advanced=True),
|
||||
IO.Int.Input("width", default=1024, min=1, max=4096, step=1),
|
||||
IO.Int.Input("height", default=1024, min=1, max=4096, step=1),
|
||||
],
|
||||
outputs=[
|
||||
IO.File3DPointCloudAny.Output(display_name="model_3d"),
|
||||
IO.Load3DModelInfo.Output(display_name="model_3d_info"),
|
||||
IO.Load3DCamera.Output(display_name="camera_info"),
|
||||
IO.Int.Output(display_name="width"),
|
||||
IO.Int.Output(display_name="height"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_3d: Types.File3D, viewport_state, width: int, height: int, **kwargs) -> IO.NodeOutput:
|
||||
filename = f"preview_pointcloud_{uuid.uuid4().hex}.{model_3d.format}"
|
||||
model_3d.save_to(os.path.join(folder_paths.get_temp_directory(), filename))
|
||||
|
||||
camera_info_input = kwargs.get("camera_info", None)
|
||||
camera_info = camera_info_input if camera_info_input is not None else viewport_state['camera_info']
|
||||
model_3d_info_input = kwargs.get("model_3d_info", None)
|
||||
model_3d_info = model_3d_info_input if model_3d_info_input is not None else viewport_state.get('model_3d_info', [])
|
||||
return IO.NodeOutput(
|
||||
model_3d,
|
||||
model_3d_info,
|
||||
camera_info,
|
||||
width,
|
||||
height,
|
||||
ui=UI.PreviewUI3DAdvanced(filename, camera_info, model_3d_info),
|
||||
@ -189,6 +324,8 @@ class Load3DExtension(ComfyExtension):
|
||||
Load3D,
|
||||
Preview3D,
|
||||
Preview3DAdvanced,
|
||||
PreviewGaussianSplat,
|
||||
PreviewPointCloud,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.ldm.colormap import turbo as _turbo
|
||||
from comfy.ldm.moge.model import MoGeModel
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
|
||||
@ -27,19 +28,6 @@ MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
||||
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
|
||||
|
||||
|
||||
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
|
||||
finite = torch.isfinite(points).all(dim=-1)
|
||||
|
||||
@ -6,24 +6,24 @@ from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
class AspectRatio(str, Enum):
|
||||
SQUARE = "1:1 (Square)"
|
||||
PHOTO_V = "2:3 (Portrait Photo)"
|
||||
PHOTO_H = "3:2 (Photo)"
|
||||
STANDARD_V = "3:4 (Portrait Standard)"
|
||||
STANDARD_H = "4:3 (Standard)"
|
||||
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
|
||||
WIDESCREEN_H = "16:9 (Widescreen)"
|
||||
ULTRAWIDE_H = "21:9 (Ultrawide)"
|
||||
PHOTO_V = "2:3 (Portrait Photo)"
|
||||
STANDARD_V = "3:4 (Portrait Standard)"
|
||||
WIDESCREEN_V = "9:16 (Portrait Widescreen)"
|
||||
|
||||
|
||||
ASPECT_RATIOS: dict[AspectRatio, tuple[int, int]] = {
|
||||
AspectRatio.SQUARE: (1, 1),
|
||||
AspectRatio.PHOTO_V: (2, 3),
|
||||
AspectRatio.PHOTO_H: (3, 2),
|
||||
AspectRatio.STANDARD_V: (3, 4),
|
||||
AspectRatio.STANDARD_H: (4, 3),
|
||||
AspectRatio.WIDESCREEN_V: (9, 16),
|
||||
AspectRatio.WIDESCREEN_H: (16, 9),
|
||||
AspectRatio.ULTRAWIDE_H: (21, 9),
|
||||
AspectRatio.PHOTO_V: (2, 3),
|
||||
AspectRatio.STANDARD_V: (3, 4),
|
||||
AspectRatio.WIDESCREEN_V: (9, 16),
|
||||
}
|
||||
|
||||
|
||||
@ -50,26 +50,35 @@ class ResolutionSelector(io.ComfyNode):
|
||||
min=0.1,
|
||||
max=16.0,
|
||||
step=0.1,
|
||||
tooltip="Target total megapixels. 1.0 MP ≈ 1024×1024 for square.",
|
||||
tooltip="Target total megapixels. 1.0 MP ≈ 1024x1024 for square.",
|
||||
),
|
||||
io.Int.Input(
|
||||
id="multiple",
|
||||
default=8,
|
||||
min=8,
|
||||
max=128,
|
||||
step=4,
|
||||
tooltip="Nearest multiple of the result to set the selected resolution to.",
|
||||
advanced=True,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Int.Output(
|
||||
"width", tooltip="Calculated width in pixels (multiple of 8)."
|
||||
"width", tooltip="Calculated width in pixels multiplied by the selected multiple."
|
||||
),
|
||||
io.Int.Output(
|
||||
"height", tooltip="Calculated height in pixels (multiple of 8)."
|
||||
"height", tooltip="Calculated height in pixels multiplied by the selected multiple."
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, aspect_ratio: str, megapixels: float) -> io.NodeOutput:
|
||||
def execute(cls, aspect_ratio: str, megapixels: float, multiple: int) -> io.NodeOutput:
|
||||
w_ratio, h_ratio = ASPECT_RATIOS[aspect_ratio]
|
||||
total_pixels = megapixels * 1024 * 1024
|
||||
scale = math.sqrt(total_pixels / (w_ratio * h_ratio))
|
||||
width = round(w_ratio * scale / 8) * 8
|
||||
height = round(h_ratio * scale / 8) * 8
|
||||
width = round(w_ratio * scale / multiple) * multiple
|
||||
height = round(h_ratio * scale / multiple) * multiple
|
||||
return io.NodeOutput(width, height)
|
||||
|
||||
|
||||
|
||||
@ -337,6 +337,12 @@ class SaveGLB(IO.ComfyNode):
|
||||
IO.File3DFBX,
|
||||
IO.File3DSTL,
|
||||
IO.File3DUSDZ,
|
||||
IO.File3DPLY,
|
||||
IO.File3DSPLAT,
|
||||
IO.File3DSPZ,
|
||||
IO.File3DKSPLAT,
|
||||
IO.File3DSplatAny,
|
||||
IO.File3DPointCloudAny,
|
||||
IO.File3DAny,
|
||||
],
|
||||
tooltip="Mesh or 3D file to save",
|
||||
|
||||
325
comfy_extras/nodes_scail.py
Normal file
325
comfy_extras/nodes_scail.py
Normal file
@ -0,0 +1,325 @@
|
||||
"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
|
||||
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import nodes
|
||||
import node_helpers
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
|
||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||
|
||||
|
||||
# Model was trained on these exact colors; deviating degrades multi-identity quality.
|
||||
DEFAULT_PALETTE = [
|
||||
(0.0, 0.0, 1.0), # Blue
|
||||
(1.0, 0.0, 0.0), # Red
|
||||
(0.0, 1.0, 0.0), # Green
|
||||
(1.0, 0.0, 1.0), # Magenta
|
||||
(0.0, 1.0, 1.0), # Cyan
|
||||
(1.0, 1.0, 0.0), # Yellow
|
||||
]
|
||||
|
||||
|
||||
def _unpack(track_data):
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
return None
|
||||
return unpack_masks(packed)
|
||||
|
||||
|
||||
def _first_frame_cx_area(masks_bool):
|
||||
first = masks_bool[0].float()
|
||||
H, W = first.shape[-2], first.shape[-1]
|
||||
n_pixels = H * W
|
||||
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
|
||||
area = first.sum(dim=(-1, -2)).clamp_(min=1)
|
||||
cx = (first * grid_x).sum(dim=(-1, -2)) / area
|
||||
return (cx / W).tolist(), (area / n_pixels).tolist()
|
||||
|
||||
|
||||
def _subset_track_data(track_data, obj_indices):
|
||||
out = dict(track_data)
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or not obj_indices:
|
||||
out["packed_masks"] = None
|
||||
if "scores" in out:
|
||||
out["scores"] = []
|
||||
return out
|
||||
out["packed_masks"] = packed[:, obj_indices].contiguous()
|
||||
scores = track_data.get("scores")
|
||||
if scores is not None:
|
||||
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
|
||||
return out
|
||||
|
||||
|
||||
def _render_colored_masks(track_data, background="black"):
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
device = comfy.model_management.intermediate_device()
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
||||
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
||||
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
|
||||
return out
|
||||
T, N_obj = packed.shape[0], packed.shape[1]
|
||||
colors = torch.tensor(
|
||||
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
masks_full = unpack_masks(packed.to(device)).float()
|
||||
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
|
||||
masks_full = F.interpolate(
|
||||
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
|
||||
).view(T, N_obj, H, W) > 0.5
|
||||
any_mask = masks_full.any(dim=1)
|
||||
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1)
|
||||
color_overlay = colors[obj_idx_map]
|
||||
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
|
||||
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
|
||||
|
||||
|
||||
def _extract_mask_to_28ch(rgb_video):
|
||||
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
|
||||
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
|
||||
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
|
||||
T, H, W, _ = rgb_video.shape
|
||||
_ON_THRESH = 225.0 / 255.0
|
||||
mask = rgb_video.movedim(-1, 1).float()
|
||||
R = (mask[:, 0:1] > _ON_THRESH).float()
|
||||
G = (mask[:, 1:2] > _ON_THRESH).float()
|
||||
B = (mask[:, 2:3] > _ON_THRESH).float()
|
||||
nR, nG, nB = 1 - R, 1 - G, 1 - B
|
||||
binary_7ch = torch.cat([
|
||||
R * G * B, # white
|
||||
R * nG * nB, # red
|
||||
nR * G * nB, # green
|
||||
nR * nG * B, # blue
|
||||
R * G * nB, # yellow
|
||||
R * nG * B, # magenta
|
||||
nR * G * B, # cyan
|
||||
], dim=1)
|
||||
H_lat, W_lat = H, W
|
||||
for _ in range(3):
|
||||
H_lat = (H_lat + 1) // 2
|
||||
W_lat = (W_lat + 1) // 2
|
||||
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
|
||||
T_latent = (T - 1) // 4 + 1
|
||||
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
|
||||
out = padded.view(T_latent, 28, H_lat, W_lat)
|
||||
return out.unsqueeze(0)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSCAILToVideo",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
|
||||
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
|
||||
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
|
||||
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
|
||||
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."),
|
||||
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
|
||||
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
|
||||
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
|
||||
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
|
||||
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
|
||||
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
noise_mask = None
|
||||
|
||||
ref_mask_flag = not replacement_mode
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
|
||||
|
||||
prev_trimmed = None
|
||||
if previous_frames is not None and previous_frames.shape[0] > 0:
|
||||
prev_trimmed = previous_frames[-previous_frame_count:]
|
||||
video_frame_offset -= prev_trimmed.shape[0]
|
||||
video_frame_offset = max(0, video_frame_offset)
|
||||
|
||||
ref_latent = None
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte
|
||||
if replacement_mode and reference_image_mask is not None:
|
||||
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
|
||||
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
|
||||
reference_image = reference_image * is_char
|
||||
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
if pose_video.shape[0] <= video_frame_offset:
|
||||
pose_video = None
|
||||
else:
|
||||
pose_video = pose_video[video_frame_offset:]
|
||||
if pose_video_mask is not None:
|
||||
if pose_video_mask.shape[0] <= video_frame_offset:
|
||||
pose_video_mask = None
|
||||
else:
|
||||
pose_video_mask = pose_video_mask[video_frame_offset:]
|
||||
|
||||
# Truncate pose+mask jointly to the shorter of the two, capped at length.
|
||||
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
|
||||
if ts:
|
||||
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
|
||||
if pose_video is not None:
|
||||
pose_video = pose_video[:T_kept]
|
||||
if pose_video_mask is not None:
|
||||
pose_video_mask = pose_video_mask[:T_kept]
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
|
||||
if pose_video_mask is not None:
|
||||
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
|
||||
|
||||
if reference_image_mask is not None:
|
||||
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw)
|
||||
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
|
||||
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
|
||||
|
||||
if prev_trimmed is not None:
|
||||
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
prev_latent = vae.encode(pf[:, :, :, :3])
|
||||
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
|
||||
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
|
||||
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
|
||||
noise_mask[:, :, :prev_latent_frames] = 0.0
|
||||
|
||||
out_latent = {"samples": latent}
|
||||
if noise_mask is not None:
|
||||
out_latent["noise_mask"] = noise_mask
|
||||
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
|
||||
|
||||
|
||||
class SCAIL2ColoredMask(io.ComfyNode):
|
||||
"""Render SAM3 tracks for the driving pose video and (optionally) the reference
|
||||
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
|
||||
across both outputs guarantees identity K maps to the same color on both
|
||||
sides, for multi-person workflow consistency.
|
||||
reference_image_mask is always rendered black-bg (model convention)
|
||||
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SCAIL2ColoredMask",
|
||||
display_name="Create SCAIL-2 Colored Mask",
|
||||
category="conditioning/video_models/scail",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
|
||||
SAM3TrackData.Input("ref_track_data", optional=True,
|
||||
tooltip="SAM3 track of the reference image."),
|
||||
io.String.Input("object_indices", default="",
|
||||
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
|
||||
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
|
||||
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
|
||||
io.Boolean.Input("replacement_mode", default=False,
|
||||
tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). "
|
||||
"True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("pose_video_mask"),
|
||||
io.Image.Output("reference_image_mask"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
||||
def _prep(td):
|
||||
masks_bool = _unpack(td)
|
||||
if sort_by != "none" and masks_bool is not None:
|
||||
cx, area = _first_frame_cx_area(masks_bool)
|
||||
if sort_by == "left_to_right":
|
||||
order = sorted(range(len(cx)), key=lambda i: cx[i])
|
||||
else: # "area"
|
||||
order = sorted(range(len(area)), key=lambda i: -area[i])
|
||||
td = _subset_track_data(td, order)
|
||||
if object_indices.strip():
|
||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||
packed = td.get("packed_masks")
|
||||
n_obj = packed.shape[1] if packed is not None else 0
|
||||
indices = [i for i in indices if 0 <= i < n_obj]
|
||||
td = _subset_track_data(td, indices)
|
||||
return td
|
||||
|
||||
drv = _prep(driving_track_data)
|
||||
# Animation: driving=black, ref=white. Replacement: driving=white, ref=black.
|
||||
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
|
||||
ref_bg = "black" if replacement_mode else "white"
|
||||
|
||||
if ref_track_data is not None:
|
||||
ref = _prep(ref_track_data)
|
||||
reference_image_mask = _render_colored_masks(ref, ref_bg)
|
||||
else:
|
||||
H, W = drv["orig_size"]
|
||||
fill_value = 1.0 if ref_bg == "white" else 0.0
|
||||
reference_image_mask = torch.full((1, H, W, 3), fill_value, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
|
||||
return io.NodeOutput(mask_video, reference_image_mask)
|
||||
|
||||
|
||||
class SCAILExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
WanSCAILToVideo,
|
||||
SCAIL2ColoredMask,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SCAILExtension:
|
||||
return SCAILExtension()
|
||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.conds import CONDRegular, CONDList
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
@ -120,6 +121,11 @@ def process_cond_list(d, prefix=""):
|
||||
process_cond_list(v, f"{prefix}.{k}")
|
||||
elif isinstance(v, torch.Tensor):
|
||||
d[k] = v.clone()
|
||||
elif isinstance(v, CONDList):
|
||||
v.cond = [t.detach() if isinstance(t, torch.Tensor) else t for t in v.cond]
|
||||
elif isinstance(v, CONDRegular):
|
||||
if isinstance(v.cond, torch.Tensor):
|
||||
v.cond = v.cond.detach()
|
||||
elif isinstance(v, (list, tuple)):
|
||||
for index, item in enumerate(v):
|
||||
process_cond_list(item, f"{prefix}.{k}.{index}")
|
||||
@ -1143,45 +1149,45 @@ class TrainLoraNode(io.ComfyNode):
|
||||
# Process conditioning
|
||||
positive = _process_conditioning(positive)
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
# Setup model and dtype
|
||||
mp = model.clone(force_deepcopy=True)
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
mp.model.requires_grad_(False).train()
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
@ -136,6 +136,17 @@ class CreateVideo(io.ComfyNode):
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
io.Int.Input(
|
||||
"bit_depth",
|
||||
min=8,
|
||||
max=10,
|
||||
default=8,
|
||||
step=2,
|
||||
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
|
||||
" banding, but some players and downstream nodes may not support it.",
|
||||
optional=True,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
@ -143,9 +154,14 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
def execute(
|
||||
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
|
||||
) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
|
||||
bit_depth=bit_depth,
|
||||
)
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -156,7 +172,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
],
|
||||
@ -164,13 +180,14 @@ class GetVideoComponents(io.ComfyNode):
|
||||
io.Image.Output(display_name="images"),
|
||||
io.Audio.Output(display_name="audio"),
|
||||
io.Float.Output(display_name="fps"),
|
||||
io.Int.Output(display_name="bit_depth"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
|
||||
@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSCAILToVideo",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("reference_image", optional=True),
|
||||
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
|
||||
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
|
||||
ref_latent = None
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension):
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
WanSCAILToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
import importlib.util
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import subprocess
|
||||
import re
|
||||
|
||||
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
|
||||
def get_gpu_names():
|
||||
@ -77,11 +78,24 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
def get_raw_cuda_version(version_str):
|
||||
match = re.search(r'\+cu(\d+)', version_str)
|
||||
if match:
|
||||
try:
|
||||
return int(match.group(1))
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
if not args.cuda_malloc:
|
||||
try:
|
||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
cuda_version = get_raw_cuda_version(version)
|
||||
if cuda_version is not None and cuda_version >= 130:
|
||||
args.cuda_malloc = True
|
||||
else:
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
10
execution.py
10
execution.py
@ -40,6 +40,7 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||
@ -199,6 +200,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
if io.Hidden.comfy_usage_source.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
@ -215,6 +218,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
if h[x] == "COMFY_USAGE_SOURCE":
|
||||
input_data_all[x] = [extra_data.get("comfy_usage_source", None)]
|
||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||
return input_data_all, missing_keys, v3_data
|
||||
|
||||
@ -418,6 +423,7 @@ def _is_intermediate_output(dynprompt, node_id):
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||
|
||||
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||
if server.client_id is None:
|
||||
return
|
||||
@ -552,6 +558,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
# Enrich at output-processing time (not in the send path) so assets
|
||||
# are registered even when no client is connected, and the asset id
|
||||
# flows into ui_outputs and the cache alongside the raw entries.
|
||||
output_ui = enrich_output_with_assets(output_ui)
|
||||
ui_outputs[unique_id] = {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
|
||||
20
main.py
20
main.py
@ -26,6 +26,7 @@ import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -37,7 +38,19 @@ if __name__ == "__main__":
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang)
|
||||
if __name__ == "__main__" and args.debug_hang:
|
||||
dumping_traceback = False
|
||||
|
||||
def dump_traceback_on_sigint(signum, frame):
|
||||
global dumping_traceback
|
||||
if dumping_traceback:
|
||||
raise KeyboardInterrupt
|
||||
dumping_traceback = True
|
||||
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
|
||||
raise KeyboardInterrupt
|
||||
|
||||
signal.signal(signal.SIGINT, dump_traceback_on_sigint)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
@ -477,6 +490,11 @@ def start_comfyui(asyncio_loop=None):
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
))
|
||||
|
||||
# Re-apply Comfy's cuDNN benchmark policy after custom-node imports. Benchmark
|
||||
# mode can request near-card-sized autotune workspaces, and some custom nodes set it at import time.
|
||||
comfy.model_management.set_cudnn_benchmark()
|
||||
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -2406,6 +2406,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_video.py",
|
||||
"nodes_lumina2.py",
|
||||
"nodes_wan.py",
|
||||
"nodes_bernini.py",
|
||||
"nodes_lotus.py",
|
||||
"nodes_hunyuan3d.py",
|
||||
"nodes_primitive.py",
|
||||
@ -2452,6 +2453,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py",
|
||||
"nodes_scail.py",
|
||||
"nodes_void.py",
|
||||
"nodes_wandancer.py",
|
||||
"nodes_hidream_o1.py",
|
||||
@ -2459,7 +2461,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_moge.py",
|
||||
"nodes_mediapipe.py",
|
||||
"nodes_gaussian_splat.py",
|
||||
"nodes_triposplat.py"
|
||||
"nodes_triposplat.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
241
openapi.yaml
241
openapi.yaml
@ -3,11 +3,6 @@ components:
|
||||
Asset:
|
||||
description: Represents a user-owned asset (image, video, or other generated output).
|
||||
properties:
|
||||
asset_hash:
|
||||
deprecated: true
|
||||
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
created_at:
|
||||
description: Timestamp when the asset was created
|
||||
format: date-time
|
||||
@ -16,8 +11,12 @@ components:
|
||||
description: Display name of the asset. Mirrors name for backwards compatibility.
|
||||
nullable: true
|
||||
type: string
|
||||
file_path:
|
||||
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
|
||||
nullable: true
|
||||
type: string
|
||||
hash:
|
||||
description: Blake3 hash of the asset content. Preferred over asset_hash.
|
||||
description: Blake3 hash of the asset content.
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
id:
|
||||
@ -139,17 +138,16 @@ components:
|
||||
AssetUpdated:
|
||||
description: Response returned when an existing asset is successfully updated.
|
||||
properties:
|
||||
asset_hash:
|
||||
deprecated: true
|
||||
description: 'Deprecated: use hash instead. Blake3 hash of the asset content.'
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
display_name:
|
||||
description: Display name of the asset. Mirrors name for backwards compatibility.
|
||||
nullable: true
|
||||
type: string
|
||||
file_path:
|
||||
description: Relative path in global-namespace-root form (e.g. "models/checkpoints/flux.safetensors")
|
||||
nullable: true
|
||||
type: string
|
||||
hash:
|
||||
description: Blake3 hash of the asset content. Preferred over asset_hash.
|
||||
description: Blake3 hash of the asset content.
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
id:
|
||||
@ -828,7 +826,11 @@ components:
|
||||
type: string
|
||||
type: object
|
||||
PaginationInfo:
|
||||
description: Offset/limit-based pagination metadata included in list responses.
|
||||
description: |
|
||||
Pagination metadata included in list responses. Supports both legacy
|
||||
offset/limit pagination and cursor-based pagination. When cursor-based
|
||||
pagination is used, `next_cursor` is the primary pagination token and
|
||||
`offset`/`total` may be zero.
|
||||
properties:
|
||||
has_more:
|
||||
description: Whether more items are available beyond this page
|
||||
@ -837,12 +839,19 @@ components:
|
||||
description: Items per page
|
||||
minimum: 1
|
||||
type: integer
|
||||
next_cursor:
|
||||
description: |
|
||||
Opaque cursor for the next page. Pass this value as the `after`
|
||||
query parameter on the next request. Empty or absent when there
|
||||
are no more results.
|
||||
type: string
|
||||
offset:
|
||||
description: Current offset (0-based)
|
||||
deprecated: true
|
||||
description: 'Current offset (0-based). Deprecated: use cursor-based pagination.'
|
||||
minimum: 0
|
||||
type: integer
|
||||
total:
|
||||
description: Total number of items matching filters
|
||||
description: Total number of items matching filters (may be 0 when using cursor pagination)
|
||||
minimum: 0
|
||||
type: integer
|
||||
required:
|
||||
@ -887,6 +896,11 @@ components:
|
||||
additionalProperties: true
|
||||
description: The workflow graph to execute
|
||||
type: object
|
||||
prompt_id:
|
||||
description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one.
|
||||
format: uuid
|
||||
nullable: true
|
||||
type: string
|
||||
workflow_id:
|
||||
description: UUID identifying the cloud workflow entity to associate with this job
|
||||
type: string
|
||||
@ -1053,6 +1067,9 @@ components:
|
||||
comfyui_version:
|
||||
description: ComfyUI version
|
||||
type: string
|
||||
deploy_environment:
|
||||
description: How this ComfyUI instance is deployed (e.g. cloud, local-git, local-portable, local-desktop)
|
||||
type: string
|
||||
embedded_python:
|
||||
description: Whether using embedded Python
|
||||
type: boolean
|
||||
@ -1518,17 +1535,11 @@ paths:
|
||||
schema:
|
||||
default: true
|
||||
type: boolean
|
||||
- description: Filter assets by exact content hash. Preferred over asset_hash.
|
||||
- description: Filter assets by exact content hash.
|
||||
in: query
|
||||
name: hash
|
||||
schema:
|
||||
type: string
|
||||
- deprecated: true
|
||||
description: 'Deprecated: use hash instead. Filter assets by exact content hash.'
|
||||
in: query
|
||||
name: asset_hash
|
||||
schema:
|
||||
type: string
|
||||
- description: |
|
||||
Opaque cursor for keyset pagination. Pass the `next_cursor` value
|
||||
from the previous response to fetch the next page. When provided,
|
||||
@ -1571,42 +1582,12 @@ paths:
|
||||
- file
|
||||
post:
|
||||
description: |
|
||||
Uploads a new asset to the system with associated metadata.
|
||||
Supports two upload methods:
|
||||
1. Direct file upload (multipart/form-data)
|
||||
2. URL-based upload (application/json with source: "url")
|
||||
Creates a new asset from a direct file upload (multipart/form-data) with associated metadata.
|
||||
|
||||
If an asset with the same hash already exists, returns the existing asset.
|
||||
operationId: uploadAsset
|
||||
operationId: createAsset
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
properties:
|
||||
name:
|
||||
description: Display name for the asset (used to determine file extension)
|
||||
type: string
|
||||
preview_id:
|
||||
description: Optional preview asset ID
|
||||
format: uuid
|
||||
type: string
|
||||
tags:
|
||||
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
url:
|
||||
description: HTTP/HTTPS URL to download the asset from
|
||||
format: uri
|
||||
type: string
|
||||
user_metadata:
|
||||
additionalProperties: true
|
||||
description: Custom metadata to store with the asset
|
||||
type: object
|
||||
required:
|
||||
- url
|
||||
- name
|
||||
type: object
|
||||
multipart/form-data:
|
||||
schema:
|
||||
properties:
|
||||
@ -1614,6 +1595,10 @@ paths:
|
||||
description: The asset file to upload
|
||||
format: binary
|
||||
type: string
|
||||
hash:
|
||||
description: Content hash of the file.
|
||||
pattern: ^(blake3|sha256):[a-f0-9]{64}$
|
||||
type: string
|
||||
id:
|
||||
description: Optional asset ID for idempotent creation. If provided and asset exists, returns existing asset.
|
||||
format: uuid
|
||||
@ -1629,10 +1614,8 @@ paths:
|
||||
format: uuid
|
||||
type: string
|
||||
tags:
|
||||
description: Freeform tags for the asset. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
description: JSON-encoded array of freeform tag strings, e.g. '["models","checkpoint"]'. Common types include "models", "input", "output", and "temp", but any tag can be used in any order.
|
||||
type: string
|
||||
user_metadata:
|
||||
description: Custom JSON metadata as a string
|
||||
type: string
|
||||
@ -1641,36 +1624,32 @@ paths:
|
||||
type: object
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: |
|
||||
Asset already existed for this user (deduplicated by content hash); the
|
||||
existing asset is returned with created_new=false.
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: Asset created successfully
|
||||
description: Asset created successfully (created_new=true)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request (bad file, invalid URL, invalid content type, etc.)
|
||||
description: Invalid request (bad file, invalid content type, etc.)
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unauthorized
|
||||
"403":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source URL requires authentication or access denied
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Source URL not found
|
||||
"413":
|
||||
content:
|
||||
application/json:
|
||||
@ -1683,19 +1662,13 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unsupported media type
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Download failed due to network error or timeout
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error
|
||||
summary: Upload a new asset
|
||||
summary: Create a new asset
|
||||
tags:
|
||||
- file
|
||||
/api/assets/{id}:
|
||||
@ -1730,7 +1703,7 @@ paths:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Asset cannot be deleted because it is referenced by another resource (e.g., workflow version)
|
||||
description: 'Asset cannot be deleted because it is referenced by another resource, e.g. a workflow version (error code: ASSET_IN_USE)'
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
@ -1783,7 +1756,7 @@ paths:
|
||||
description: |
|
||||
Updates an asset's metadata. At least one field must be provided.
|
||||
Only name, mime_type, preview_id, and user_metadata can be updated.
|
||||
For tag management, use the dedicated PUT /api/assets/{id}/tags endpoint.
|
||||
For tag management, use POST (add) and DELETE (remove) /api/assets/{id}/tags.
|
||||
operationId: updateAsset
|
||||
parameters:
|
||||
- description: Asset ID
|
||||
@ -1982,76 +1955,6 @@ paths:
|
||||
summary: Add tags to asset
|
||||
tags:
|
||||
- file
|
||||
put:
|
||||
description: Adds and removes tags from an asset in a single operation
|
||||
operationId: updateAssetTags
|
||||
parameters:
|
||||
- description: Asset ID
|
||||
in: path
|
||||
name: id
|
||||
required: true
|
||||
schema:
|
||||
format: uuid
|
||||
type: string
|
||||
requestBody:
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
description: At least one of add or remove must contain items. Empty arrays are allowed when the other array has items.
|
||||
minProperties: 1
|
||||
properties:
|
||||
add:
|
||||
description: Tags to add to the asset. Can be empty if remove has items.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
remove:
|
||||
description: Tags to remove from the asset. Can be empty if add has items.
|
||||
items:
|
||||
type: string
|
||||
type: array
|
||||
type: object
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/TagsModificationResponse'
|
||||
description: Tags updated successfully
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Invalid request
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Unauthorized
|
||||
"404":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Asset not found
|
||||
"422":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Reserved tag validation error
|
||||
"500":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Internal server error
|
||||
summary: Update asset tags
|
||||
tags:
|
||||
- file
|
||||
/api/assets/from-hash:
|
||||
post:
|
||||
description: |
|
||||
@ -2065,8 +1968,8 @@ paths:
|
||||
schema:
|
||||
properties:
|
||||
hash:
|
||||
description: Hash of the existing asset. Supports Blake3 (blake3:) or SHA256 (sha256:) formats
|
||||
pattern: ^(blake3|sha256):[a-f0-9]{64}$
|
||||
description: 'Blake3 content hash of the existing asset (blake3: prefix)'
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
mime_type:
|
||||
description: MIME type of the asset (e.g., "image/png", "video/mp4")
|
||||
@ -2090,12 +1993,20 @@ paths:
|
||||
type: object
|
||||
required: true
|
||||
responses:
|
||||
"200":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: |
|
||||
Asset reference already existed for this user (deduplicated by content
|
||||
hash); the existing asset is returned with created_new=false.
|
||||
"201":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/AssetCreated'
|
||||
description: Asset reference created successfully
|
||||
description: Asset reference created successfully (created_new=true)
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
@ -2887,7 +2798,21 @@ paths:
|
||||
- asc
|
||||
- desc
|
||||
type: string
|
||||
- description: Pagination offset (0-based)
|
||||
- description: |
|
||||
Opaque cursor for keyset pagination. Pass the `next_cursor` value
|
||||
from a previous response to fetch the next page.
|
||||
Cursor pagination is supported only when `sort_by=create_time`
|
||||
(default). If `sort_by=execution_time`, `after` is ignored and
|
||||
offset/limit pagination is used.
|
||||
Cursors are opaque base64url payloads — clients should treat them
|
||||
as strings and not parse the contents.
|
||||
example: eyJzIjoiY3JlYXRlX3RpbWUiLCJ2IjoiMTcxNjIwMDAwMDAwMDAwMCIsImlkIjoiYTFiMmMzZDQtZTVmNi03YTg5LWIwYzEtZDJlM2Y0YTViNmM3In0
|
||||
in: query
|
||||
name: after
|
||||
schema:
|
||||
type: string
|
||||
- deprecated: true
|
||||
description: 'Pagination offset (0-based). Deprecated: prefer cursor-based pagination via `after`.'
|
||||
in: query
|
||||
name: offset
|
||||
schema:
|
||||
@ -2909,6 +2834,12 @@ paths:
|
||||
schema:
|
||||
$ref: '#/components/schemas/JobsListResponse'
|
||||
description: Success - Jobs retrieved
|
||||
"400":
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/ErrorResponse'
|
||||
description: Bad request (e.g. malformed pagination cursor).
|
||||
"401":
|
||||
content:
|
||||
application/json:
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
comfyui-frontend-package==1.44.19
|
||||
comfyui-workflow-templates==0.9.94
|
||||
comfyui-embedded-docs==0.5.2
|
||||
comfyui-frontend-package==1.45.15
|
||||
comfyui-workflow-templates==0.9.98
|
||||
comfyui-embedded-docs==0.5.3
|
||||
torch
|
||||
torchsde
|
||||
torchvision
|
||||
@ -23,7 +23,7 @@ SQLAlchemy>=2.0.0
|
||||
filelock
|
||||
av>=16.0.0
|
||||
comfy-kitchen==0.2.10
|
||||
comfy-aimdo==0.4.8
|
||||
comfy-aimdo==0.4.9
|
||||
requests
|
||||
simpleeval>=1.0.0
|
||||
blake3
|
||||
|
||||
34
server.py
34
server.py
@ -8,7 +8,7 @@ import time
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
|
||||
import uuid
|
||||
import urllib
|
||||
import json
|
||||
@ -27,6 +27,7 @@ import logging
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
@ -690,6 +691,7 @@ class PromptServer():
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
"deploy_environment": get_deploy_environment(),
|
||||
"argv": sys.argv
|
||||
},
|
||||
"devices": device_entries
|
||||
@ -942,7 +944,21 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
client_prompt_id = json_data.get("prompt_id")
|
||||
if client_prompt_id is None:
|
||||
# Absent or explicit null: the server mints the id.
|
||||
prompt_id = str(uuid.uuid4())
|
||||
else:
|
||||
try:
|
||||
prompt_id = validate_job_id(client_prompt_id)
|
||||
except ValueError:
|
||||
error = {
|
||||
"type": "invalid_prompt_id",
|
||||
"message": "prompt_id must be a valid UUID",
|
||||
"details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one",
|
||||
"extra_info": {}
|
||||
}
|
||||
return web.json_response({"error": error, "node_errors": {}}, status=400)
|
||||
|
||||
partial_execution_targets = None
|
||||
if "partial_execution_targets" in json_data:
|
||||
@ -957,6 +973,11 @@ class PromptServer():
|
||||
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
|
||||
if "comfy_usage_source" not in extra_data:
|
||||
usage_source = request.headers.get("Comfy-Usage-Source")
|
||||
if usage_source:
|
||||
extra_data["comfy_usage_source"] = usage_source
|
||||
if valid[0]:
|
||||
outputs_to_execute = valid[2]
|
||||
sensitive = {}
|
||||
@ -1253,6 +1274,15 @@ class PromptServer():
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
if args.debug_hang:
|
||||
logging.info(
|
||||
f"{'-' * 80}\n"
|
||||
"ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n"
|
||||
"the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n"
|
||||
"terminal to dump the python backtraces for debugging. Please attach the extra\n"
|
||||
"debug info to your bug report.\n"
|
||||
f"{'-' * 80}"
|
||||
)
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
|
||||
@ -6,6 +6,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
@ -188,9 +189,17 @@ def _post_multipart_asset(
|
||||
|
||||
@pytest.fixture
|
||||
def make_asset_bytes() -> Callable[[str, int], bytes]:
|
||||
# Salt content per test so it never collides with assets left over from
|
||||
# earlier tests. Delete is now always a soft delete (content is preserved),
|
||||
# so the suite can no longer rely on hard-deleting content for isolation.
|
||||
# Deterministic within a test: the same (name, size) yields the same bytes.
|
||||
salt = uuid.uuid4().bytes
|
||||
|
||||
def _make(name: str, size: int = 8192) -> bytes:
|
||||
seed = sum(ord(c) for c in name) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
body = bytearray((i * 31 + seed) % 256 for i in range(size))
|
||||
body[: len(salt)] = salt[:size]
|
||||
return bytes(body)
|
||||
return _make
|
||||
|
||||
|
||||
@ -212,7 +221,7 @@ def asset_factory(http: requests.Session, api_base: str):
|
||||
|
||||
for aid in created:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -227,7 +236,11 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
files = {"file": (name, b"A" * 4096, "application/octet-stream")}
|
||||
# Unique content per test so the seed always creates a fresh asset (201).
|
||||
# Delete is now always a soft delete, so content from a prior test survives
|
||||
# and would otherwise dedup this upload into an existing asset (200).
|
||||
content = uuid.uuid4().bytes + b"A" * (4096 - 16)
|
||||
files = {"file": (name, content, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
@ -260,4 +273,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
112
tests-unit/assets_test/queries/test_asset_reference_keyset.py
Normal file
112
tests-unit/assets_test/queries/test_asset_reference_keyset.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Keyset-pagination tiebreaker tests for list_references_page.
|
||||
|
||||
When multiple rows share the same primary sort value (e.g. four assets
|
||||
created in the same microsecond), the secondary `ORDER BY id` is what keeps
|
||||
keyset pagination from losing or repeating rows. This file exercises that
|
||||
branch directly against an in-memory SQLite session — engineering identical
|
||||
timestamps via HTTP is unreliable enough that we work at the query layer.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.asset_reference import list_references_page
|
||||
|
||||
|
||||
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
|
||||
asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
ref = AssetReference(
|
||||
id=str(uuid.uuid4()),
|
||||
asset_id=asset.id,
|
||||
owner_id=owner,
|
||||
name=name,
|
||||
file_path=f"/tmp/{name}",
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
last_access_time=created_at,
|
||||
is_missing=False,
|
||||
)
|
||||
session.add(ref)
|
||||
return ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order", ["desc", "asc"])
|
||||
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
|
||||
"""Four rows with the SAME created_at must paginate cleanly under cursor
|
||||
mode — no row dropped, no row repeated, despite the primary sort column
|
||||
being non-discriminating.
|
||||
"""
|
||||
shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores
|
||||
refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
|
||||
session.commit()
|
||||
|
||||
expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))
|
||||
|
||||
# Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0).
|
||||
seen: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(4): # generous loop bound; ought to be 2 iterations
|
||||
page, _tag_map, _total = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order=order,
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
seen.extend(p.id for p in page)
|
||||
# Use the last row's (created_at, id) as the next cursor input.
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen == expected_ids, (
|
||||
f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
|
||||
)
|
||||
|
||||
|
||||
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
|
||||
"""Some rows share a timestamp, some don't. The cursor must still walk
|
||||
every row exactly once regardless of where ties sit relative to a
|
||||
page boundary."""
|
||||
t1 = datetime(2024, 5, 20, 12, 0, 0)
|
||||
t2 = datetime(2024, 5, 20, 12, 0, 1)
|
||||
layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2
|
||||
refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
|
||||
session.commit()
|
||||
|
||||
all_ids = {r.id for r in refs}
|
||||
seen_set: set[str] = set()
|
||||
seen_list: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(6):
|
||||
page, _, _ = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order="desc",
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
for p in page:
|
||||
assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
|
||||
seen_set.add(p.id)
|
||||
seen_list.append(p.id)
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"
|
||||
@ -40,15 +40,15 @@ def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id
|
||||
|
||||
class TestEnsureTagsExist:
|
||||
def test_creates_new_tags(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_is_idempotent(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha"])
|
||||
ensure_tags_exist(session, ["alpha"])
|
||||
session.commit()
|
||||
|
||||
assert session.query(Tag).count() == 1
|
||||
@ -65,13 +65,6 @@ class TestEnsureTagsExist:
|
||||
session.commit()
|
||||
assert session.query(Tag).count() == 0
|
||||
|
||||
def test_tag_type_is_set(self, session: Session):
|
||||
ensure_tags_exist(session, ["system-tag"], tag_type="system")
|
||||
session.commit()
|
||||
|
||||
tag = session.query(Tag).filter_by(name="system-tag").one()
|
||||
assert tag.tag_type == "system"
|
||||
|
||||
|
||||
class TestGetReferenceTags:
|
||||
def test_returns_empty_for_no_tags(self, session: Session):
|
||||
@ -193,7 +186,7 @@ class TestMissingTagFunctions:
|
||||
def test_add_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
ensure_tags_exist(session, ["missing"])
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
@ -204,7 +197,7 @@ class TestMissingTagFunctions:
|
||||
def test_add_missing_tag_is_idempotent(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
ensure_tags_exist(session, ["missing"])
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
@ -216,7 +209,7 @@ class TestMissingTagFunctions:
|
||||
def test_remove_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
ensure_tags_exist(session, ["missing"])
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
@ -237,7 +230,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, total = list_tags_with_usage(session)
|
||||
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict["used"] == 1
|
||||
assert tag_dict["unused"] == 0
|
||||
assert total == 2
|
||||
@ -252,7 +245,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, total = list_tags_with_usage(session, include_zero=False)
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert "used" in tag_names
|
||||
assert "unused" not in tag_names
|
||||
|
||||
@ -262,7 +255,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, total = list_tags_with_usage(session, prefix="alph")
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert tag_names == {"alpha", "alphabet"}
|
||||
|
||||
def test_order_by_name(self, session: Session):
|
||||
@ -271,7 +264,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, _ = list_tags_with_usage(session, order="name_asc")
|
||||
|
||||
names = [name for name, _, _ in rows]
|
||||
names = [name for name, _ in rows]
|
||||
assert names == ["alpha", "middle", "zebra"]
|
||||
|
||||
def test_owner_visibility(self, session: Session):
|
||||
@ -287,13 +280,13 @@ class TestListTagsWithUsage:
|
||||
|
||||
# Empty owner sees only shared
|
||||
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict.get("shared-tag", 0) == 1
|
||||
assert tag_dict.get("owner-tag", 0) == 0
|
||||
|
||||
# User1 sees both
|
||||
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict.get("shared-tag", 0) == 1
|
||||
assert tag_dict.get("owner-tag", 0) == 1
|
||||
|
||||
|
||||
278
tests-unit/assets_test/services/test_cursor.py
Normal file
278
tests-unit/assets_test/services/test_cursor.py
Normal file
@ -0,0 +1,278 @@
|
||||
"""Tests for app.assets.services.cursor.
|
||||
|
||||
Cursors are opaque tokens internal to this server — these tests cover
|
||||
round-tripping, validation, and length caps, not any particular wire
|
||||
byte layout.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
MAX_CURSOR_ID_LENGTH,
|
||||
MAX_CURSOR_VALUE_LENGTH,
|
||||
MAX_ENCODED_CURSOR_LENGTH,
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
@pytest.mark.parametrize(
|
||||
"sort_field, value, id",
|
||||
[
|
||||
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
|
||||
("size", "1024", "asset-123"),
|
||||
("name", "my-asset.png", "asset-abc"),
|
||||
("name", "résumé.txt", "asset-uni"),
|
||||
("name", "foo<&>bar.png", "asset-html"),
|
||||
("name", 'quo"te\\back\nnewline.png', "asset-esc"),
|
||||
],
|
||||
)
|
||||
def test_encode_decode(self, sort_field, value, id):
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert encoded != ""
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.sort_field == sort_field
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestTimeCursor:
|
||||
def test_microsecond_precision_preserved(self):
|
||||
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
|
||||
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
|
||||
encoded = encode_cursor_from_time("created_at", ts, "id-1")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
# Value must be a microsecond integer string, not a millisecond one.
|
||||
assert payload.value == "1716209600123456"
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded == ts
|
||||
|
||||
def test_decode_returns_utc(self):
|
||||
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc")
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded.tzinfo == timezone.utc
|
||||
|
||||
def test_naive_datetime_rejected_on_encode(self):
|
||||
naive = datetime(2024, 5, 20, 12, 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor_from_time("created_at", naive, "id-1")
|
||||
|
||||
def test_non_integer_value_rejected_on_decode(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc"))
|
||||
|
||||
def test_none_payload_rejected(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(None)
|
||||
|
||||
def test_non_utc_aware_normalized(self):
|
||||
# Same instant, different timezone — must encode to the same micros.
|
||||
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
|
||||
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
|
||||
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
|
||||
"created_at", offset_ts, "x"
|
||||
)
|
||||
|
||||
|
||||
class TestIntCursor:
|
||||
def test_decode_int(self):
|
||||
assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024
|
||||
|
||||
def test_decode_int_rejects_non_int(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc"))
|
||||
|
||||
def test_decode_int_rejects_none(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(None)
|
||||
|
||||
|
||||
class TestInvalidInputs:
|
||||
def test_oversized_cursor(self):
|
||||
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
|
||||
with pytest.raises(InvalidCursorError, match="maximum length"):
|
||||
decode_cursor(oversized, ALLOWED)
|
||||
|
||||
def test_not_base64(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor("not base64!!!", ALLOWED)
|
||||
|
||||
def test_not_json(self):
|
||||
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_empty_id(self):
|
||||
# Encoder rejects empty id symmetrically with the decoder, so build the
|
||||
# payload manually to exercise the decoder's missing-id branch.
|
||||
raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing id"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_id(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_value(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_unsupported_sort_field(self):
|
||||
encoded = encode_cursor("execution_time", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_no_allowed_fields_rejects_everything(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ())
|
||||
|
||||
def test_non_dict_payload_rejected(self):
|
||||
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="expected object"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
|
||||
class TestEncodeAtCapsFits:
|
||||
def test_max_field_lengths_fit_wire_cap(self):
|
||||
# Worst-case payload: value and id at their per-field caps, with a long
|
||||
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
|
||||
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
|
||||
value = "v" * MAX_CURSOR_VALUE_LENGTH
|
||||
id = "i" * MAX_CURSOR_ID_LENGTH
|
||||
sort_field = "very_long_sort_field_name"
|
||||
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, (sort_field,))
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestDatetimeOverflow:
|
||||
"""Crafted cursors with extreme micros must map to InvalidCursorError,
|
||||
not OverflowError/OSError leaking as 500.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"micros_str",
|
||||
[
|
||||
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
|
||||
"-999999999999999999999", # symmetric negative — pre-epoch overflow
|
||||
],
|
||||
)
|
||||
def test_out_of_range_micros_rejected(self, micros_str):
|
||||
encoded = encode_cursor("created_at", micros_str, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(payload)
|
||||
|
||||
|
||||
class TestEncoderDecoderSymmetry:
|
||||
"""The encoder must never mint a cursor the decoder would reject, or the
|
||||
same server would 400 on a cursor it just handed out. Per-field caps keep
|
||||
the encoded length below the wire cap, so a freshly minted cursor always
|
||||
round-trips.
|
||||
"""
|
||||
|
||||
def test_long_name_within_cap_round_trips(self):
|
||||
"""Assets allow names up to 512 chars (`String(512)`); the cursor
|
||||
encoder must round-trip a value at that cap so a freshly minted
|
||||
cursor never fails decode on the next request."""
|
||||
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", long_name, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == long_name
|
||||
|
||||
def test_encoder_rejects_empty_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id must be non-empty"):
|
||||
encode_cursor("created_at", "1", "")
|
||||
|
||||
def test_encoder_rejects_oversized_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
|
||||
|
||||
def test_encoder_rejects_oversized_value(self):
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
|
||||
|
||||
def test_multibyte_value_at_cap_round_trips(self):
|
||||
"""A value at the char-count cap made of multibyte characters
|
||||
(e.g. 'é' = 2 UTF-8 bytes) stays under the wire cap, so it mints and
|
||||
round-trips — the per-field caps, not a mint-time length check, are
|
||||
what bound cursor size."""
|
||||
value = "é" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", value, "asset-multibyte")
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == value
|
||||
|
||||
def test_escape_heavy_value_at_cap_round_trips(self):
|
||||
"""JSON escape expansion is the worst case: each control character
|
||||
serializes to the six-byte `\\uXXXX` form. A value of 512 of them is
|
||||
the largest a cursor can get, and it still fits the wire cap, mints,
|
||||
and round-trips."""
|
||||
value = "\x01" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", value, "asset-escape")
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == value
|
||||
|
||||
|
||||
class TestOrderBinding:
|
||||
def test_order_baked_into_payload(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="asc")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.order == "asc"
|
||||
|
||||
def test_mismatched_order_rejected(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
with pytest.raises(InvalidCursorError, match="does not match request order"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="asc")
|
||||
|
||||
def test_matching_order_accepted(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
assert payload.order == "desc"
|
||||
|
||||
def test_invalid_order_token_rejected_on_encode(self):
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor("created_at", "1", "id-1", order="sideways")
|
||||
|
||||
def test_invalid_order_token_rejected_on_decode(self):
|
||||
# Hand-craft a payload with an illegal `o` value.
|
||||
raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported order"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_cursor_without_order_rejected(self):
|
||||
"""`o` is mandatory. A cursor minted without it is rejected as
|
||||
malformed rather than silently walking the keyset in whatever
|
||||
direction the request happens to ask for."""
|
||||
raw = b'{"s":"name","v":"x","id":"id-1"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing or non-string o"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
86
tests-unit/assets_test/services/test_image_dimensions.py
Normal file
86
tests-unit/assets_test/services/test_image_dimensions.py
Normal file
@ -0,0 +1,86 @@
|
||||
"""Tests for the image_dimensions service."""
|
||||
from __future__ import annotations
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
|
||||
from app.assets.services.image_dimensions import extract_image_dimensions
|
||||
|
||||
|
||||
def _make_png(path: Path, size: tuple[int, int]) -> Path:
|
||||
img = Image.new("RGB", size, color=(123, 45, 67))
|
||||
img.save(path, format="PNG")
|
||||
return path
|
||||
|
||||
|
||||
def _make_jpeg(path: Path, size: tuple[int, int]) -> Path:
|
||||
img = Image.new("RGB", size, color=(10, 20, 30))
|
||||
img.save(path, format="JPEG", quality=80)
|
||||
return path
|
||||
|
||||
|
||||
class TestExtractImageDimensions:
|
||||
def test_extracts_png_dimensions(self, tmp_path: Path):
|
||||
f = _make_png(tmp_path / "rect.png", (320, 240))
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/png")
|
||||
|
||||
assert result == {"kind": "image", "width": 320, "height": 240}
|
||||
|
||||
def test_extracts_jpeg_dimensions(self, tmp_path: Path):
|
||||
f = _make_jpeg(tmp_path / "shot.jpg", (1920, 1080))
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/jpeg")
|
||||
|
||||
assert result == {"kind": "image", "width": 1920, "height": 1080}
|
||||
|
||||
def test_works_when_mime_type_is_none(self, tmp_path: Path):
|
||||
f = _make_png(tmp_path / "no_mime.png", (50, 100))
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type=None)
|
||||
|
||||
assert result == {"kind": "image", "width": 50, "height": 100}
|
||||
|
||||
def test_skips_non_image_mime_without_touching_file(self, tmp_path: Path):
|
||||
# Path doesn't need to exist — non-image MIME short-circuits.
|
||||
result = extract_image_dimensions(
|
||||
str(tmp_path / "model.safetensors"),
|
||||
mime_type="application/octet-stream",
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"mime",
|
||||
["application/json", "text/plain", "video/mp4", "audio/mpeg"],
|
||||
)
|
||||
def test_skips_all_non_image_mime_types(self, tmp_path: Path, mime: str):
|
||||
f = tmp_path / "file.bin"
|
||||
f.write_bytes(b"\x00\x01\x02")
|
||||
|
||||
assert extract_image_dimensions(str(f), mime_type=mime) is None
|
||||
|
||||
def test_returns_none_for_missing_file(self, tmp_path: Path):
|
||||
result = extract_image_dimensions(
|
||||
str(tmp_path / "does_not_exist.png"), mime_type="image/png"
|
||||
)
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_corrupt_image(self, tmp_path: Path):
|
||||
f = tmp_path / "corrupt.png"
|
||||
f.write_bytes(b"not actually a png file")
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/png")
|
||||
|
||||
assert result is None
|
||||
|
||||
def test_returns_none_for_empty_file(self, tmp_path: Path):
|
||||
f = tmp_path / "empty.png"
|
||||
f.write_bytes(b"")
|
||||
|
||||
result = extract_image_dimensions(str(f), mime_type="image/png")
|
||||
|
||||
assert result is None
|
||||
@ -4,10 +4,12 @@ from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from PIL import Image
|
||||
from sqlalchemy.orm import Session as SASession, Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference, AssetReferenceTag, Tag
|
||||
from app.assets.database.queries import get_reference_tags
|
||||
from app.assets.helpers import get_utc_now
|
||||
from app.assets.services.ingest import (
|
||||
_ingest_file_from_path,
|
||||
_register_existing_asset,
|
||||
@ -15,6 +17,11 @@ from app.assets.services.ingest import (
|
||||
)
|
||||
|
||||
|
||||
def _make_png(path: Path, size: tuple[int, int]) -> Path:
|
||||
Image.new("RGB", size, color=(80, 120, 200)).save(path, format="PNG")
|
||||
return path
|
||||
|
||||
|
||||
class TestIngestFileFromPath:
|
||||
def test_creates_asset_and_reference(self, mock_create_session, temp_dir: Path, session: Session):
|
||||
file_path = temp_dir / "test_file.bin"
|
||||
@ -279,4 +286,203 @@ class TestIngestExistingFileTagFK:
|
||||
ref_tags = sess.query(AssetReferenceTag).all()
|
||||
ref_tag_names = {rt.tag_name for rt in ref_tags}
|
||||
assert "output" in ref_tag_names
|
||||
assert "my-job" in ref_tag_names
|
||||
|
||||
|
||||
class TestIngestImageDimensions:
|
||||
"""system_metadata should carry {kind, width, height} for image assets."""
|
||||
|
||||
def test_image_asset_emits_dimensions(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
f = _make_png(temp_dir / "shot.png", (640, 480))
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:img1",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000000,
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
|
||||
assert ref.system_metadata == {
|
||||
"kind": "image",
|
||||
"width": 640,
|
||||
"height": 480,
|
||||
}
|
||||
|
||||
def test_non_image_asset_leaves_system_metadata_empty(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
f = temp_dir / "model.safetensors"
|
||||
f.write_bytes(b"not an image")
|
||||
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:safetensors1",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000000,
|
||||
mime_type="application/octet-stream",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
|
||||
assert ref.system_metadata in (None, {})
|
||||
|
||||
def test_preserves_existing_system_metadata_keys(
|
||||
self, mock_create_session, temp_dir: Path, session: Session
|
||||
):
|
||||
f = _make_png(temp_dir / "annotated.png", (100, 200))
|
||||
|
||||
# First pass populates a sentinel system_metadata key (simulating prior
|
||||
# enricher write).
|
||||
result = _ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:img-merge",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000000,
|
||||
mime_type="image/png",
|
||||
)
|
||||
ref = session.query(AssetReference).filter_by(id=result.reference_id).first()
|
||||
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/x.png"}
|
||||
session.commit()
|
||||
|
||||
# Second pass with the same path triggers the merge code path again.
|
||||
_ingest_file_from_path(
|
||||
abs_path=str(f),
|
||||
asset_hash="blake3:img-merge",
|
||||
size_bytes=f.stat().st_size,
|
||||
mtime_ns=1234567890000000001,
|
||||
mime_type="image/png",
|
||||
)
|
||||
|
||||
session.refresh(ref)
|
||||
assert ref.system_metadata["kind"] == "image"
|
||||
assert ref.system_metadata["width"] == 100
|
||||
assert ref.system_metadata["height"] == 200
|
||||
assert ref.system_metadata["source_url"] == "https://example/x.png"
|
||||
|
||||
|
||||
class TestRegisterExistingAssetBackfill:
|
||||
"""The from-hash path back-fills dimensions from a sibling reference."""
|
||||
|
||||
def _add_reference(
|
||||
self,
|
||||
session: Session,
|
||||
asset: Asset,
|
||||
name: str,
|
||||
system_metadata: dict | None = None,
|
||||
) -> AssetReference:
|
||||
now = get_utc_now()
|
||||
ref = AssetReference(
|
||||
asset_id=asset.id,
|
||||
name=name,
|
||||
owner_id="",
|
||||
created_at=now,
|
||||
updated_at=now,
|
||||
last_access_time=now,
|
||||
system_metadata=system_metadata or {},
|
||||
)
|
||||
session.add(ref)
|
||||
session.flush()
|
||||
return ref
|
||||
|
||||
def test_backfills_dimensions_from_sibling_image_reference(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:shared", size_bytes=2048, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
self._add_reference(
|
||||
session,
|
||||
asset,
|
||||
name="original.png",
|
||||
system_metadata={"kind": "image", "width": 800, "height": 600},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:shared",
|
||||
name="from_hash.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
assert ref.system_metadata.get("kind") == "image"
|
||||
assert ref.system_metadata.get("width") == 800
|
||||
assert ref.system_metadata.get("height") == 600
|
||||
|
||||
def test_no_backfill_when_sibling_has_no_image_metadata(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:nodims", size_bytes=2048, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
self._add_reference(
|
||||
session,
|
||||
asset,
|
||||
name="original.png",
|
||||
system_metadata={"base_model": "flux"}, # no kind=image
|
||||
)
|
||||
session.commit()
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:nodims",
|
||||
name="from_hash.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
meta = ref.system_metadata or {}
|
||||
assert "kind" not in meta
|
||||
assert "width" not in meta
|
||||
assert "height" not in meta
|
||||
|
||||
def test_no_backfill_when_no_sibling_exists(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:lonely", size_bytes=1024, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.commit()
|
||||
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:lonely",
|
||||
name="solo.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
assert ref.system_metadata in (None, {})
|
||||
|
||||
def test_backfill_preserves_caller_supplied_keys(
|
||||
self, mock_create_session, session: Session
|
||||
):
|
||||
asset = Asset(hash="blake3:preserve", size_bytes=2048, mime_type="image/png")
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
self._add_reference(
|
||||
session,
|
||||
asset,
|
||||
name="original.png",
|
||||
system_metadata={"kind": "image", "width": 1024, "height": 768},
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Simulate a from-hash path where the new reference already carries
|
||||
# some system_metadata (e.g. a download-provenance source_url written
|
||||
# by an earlier step). The back-fill must merge dim keys without
|
||||
# clobbering existing keys.
|
||||
result = _register_existing_asset(
|
||||
asset_hash="blake3:preserve",
|
||||
name="from_hash.png",
|
||||
owner_id="user-x",
|
||||
)
|
||||
ref = session.query(AssetReference).filter_by(id=result.ref.id).first()
|
||||
# Seed a sentinel key and re-run back-fill via a second register call
|
||||
# to exercise the merge path with pre-existing data.
|
||||
ref.system_metadata = {**(ref.system_metadata or {}), "source_url": "https://example/p"}
|
||||
session.commit()
|
||||
|
||||
assert ref.system_metadata.get("source_url") == "https://example/p"
|
||||
assert ref.system_metadata.get("kind") == "image"
|
||||
assert ref.system_metadata.get("width") == 1024
|
||||
assert ref.system_metadata.get("height") == 768
|
||||
|
||||
@ -141,7 +141,7 @@ class TestListTags:
|
||||
|
||||
rows, total = list_tags()
|
||||
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict["used"] == 1
|
||||
assert tag_dict["unused"] == 0
|
||||
assert total == 2
|
||||
@ -155,7 +155,7 @@ class TestListTags:
|
||||
|
||||
rows, total = list_tags(include_zero=False)
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert "used" in tag_names
|
||||
assert "unused" not in tag_names
|
||||
|
||||
@ -165,7 +165,7 @@ class TestListTags:
|
||||
|
||||
rows, _ = list_tags(prefix="alph")
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert tag_names == {"alpha", "alphabet"}
|
||||
|
||||
def test_order_by_name(self, mock_create_session, session: Session):
|
||||
@ -174,7 +174,7 @@ class TestListTags:
|
||||
|
||||
rows, _ = list_tags(order="name_asc")
|
||||
|
||||
names = [name for name, _, _ in rows]
|
||||
names = [name for name, _ in rows]
|
||||
assert names == ["alpha", "middle", "zebra"]
|
||||
|
||||
def test_pagination(self, mock_create_session, session: Session):
|
||||
@ -185,7 +185,7 @@ class TestListTags:
|
||||
|
||||
assert total == 5
|
||||
assert len(rows) == 2
|
||||
names = [name for name, _, _ in rows]
|
||||
names = [name for name, _ in rows]
|
||||
assert names == ["b", "c"]
|
||||
|
||||
def test_clamps_limit(self, mock_create_session, session: Session):
|
||||
|
||||
@ -45,8 +45,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse
|
||||
assert "user_metadata" in detail
|
||||
assert "filename" in detail["user_metadata"]
|
||||
|
||||
# DELETE (hard delete to also remove underlying asset and file)
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120)
|
||||
# Soft delete — the reference is hidden, content is preserved
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
# GET again -> 404
|
||||
@ -60,7 +60,7 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede
|
||||
aid = seeded_asset["id"]
|
||||
asset_hash = seeded_asset["asset_hash"]
|
||||
|
||||
# Soft-delete (default, no delete_content param)
|
||||
# Soft delete — the reference is hidden, content is preserved
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
@ -81,11 +81,10 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede
|
||||
ids = [a["id"] for a in rl.json().get("assets", [])]
|
||||
assert aid not in ids
|
||||
|
||||
# Clean up: hard-delete the soft-deleted reference and orphaned asset
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120)
|
||||
# The reference is already soft-deleted; content is preserved.
|
||||
|
||||
|
||||
def test_delete_upon_reference_count(
|
||||
def test_soft_delete_preserves_asset_identity_across_references(
|
||||
http: requests.Session, api_base: str, seeded_asset: dict
|
||||
):
|
||||
# Create a second reference to the same asset via from-hash
|
||||
@ -119,16 +118,20 @@ def test_delete_upon_reference_count(
|
||||
rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||
assert rh2.status_code == 200 # asset identity preserved (soft delete)
|
||||
|
||||
# Re-associate via from-hash, then hard-delete -> orphan content removed
|
||||
# Re-associate via from-hash: it must reuse the same preserved content
|
||||
# (created_new False AND the same hash), proving the soft deletes did not
|
||||
# destroy the underlying asset. Then soft-delete again -> still preserved.
|
||||
r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||
assert r3.status_code == 201, r3.json()
|
||||
assert r3.json()["created_new"] is False
|
||||
assert r3.json()["asset_hash"] == src_hash # reused the surviving content
|
||||
aid3 = r3.json()["id"]
|
||||
|
||||
rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120)
|
||||
rd3 = http.delete(f"{api_base}/api/assets/{aid3}", timeout=120)
|
||||
assert rd3.status_code == 204
|
||||
|
||||
rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||
assert rh3.status_code == 404 # orphan content removed
|
||||
assert rh3.status_code == 200 # content preserved (soft delete)
|
||||
|
||||
|
||||
def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
@ -249,7 +252,7 @@ def test_concurrent_delete_same_asset_info_single_204(
|
||||
|
||||
# Hit the same endpoint N times in parallel.
|
||||
n_tests = 4
|
||||
url = f"{api_base}/api/assets/{aid}?delete_content=false"
|
||||
url = f"{api_base}/api/assets/{aid}"
|
||||
|
||||
def _do_delete(delete_url):
|
||||
with requests.Session() as s:
|
||||
|
||||
@ -117,7 +117,7 @@ def test_download_missing_file_returns_404(
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
finally:
|
||||
# We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually.
|
||||
dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120)
|
||||
dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
dr.content
|
||||
|
||||
|
||||
|
||||
349
tests-unit/assets_test/test_list_cursor.py
Normal file
349
tests-unit/assets_test/test_list_cursor.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""Integration tests for cursor-based pagination on GET /api/assets.
|
||||
|
||||
These tests exercise the handler/service/query path end-to-end;
|
||||
cursor-encoding-level tests live in
|
||||
tests-unit/assets_test/services/test_cursor.py.
|
||||
"""
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
|
||||
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", tag],
|
||||
{},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
return sorted(names)
|
||||
|
||||
|
||||
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,cursor-walk",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
}
|
||||
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
pages += 1
|
||||
after = body.get("next_cursor")
|
||||
if after is None:
|
||||
break
|
||||
assert body["has_more"] is True
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
assert seen == names, f"expected {names}, got {seen}"
|
||||
# Last page should have has_more False
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": "not-a-real-cursor", "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
body = r.json()
|
||||
assert body["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
|
||||
|
||||
# Take a real cursor minted for sort=name.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-mismatch",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay against sort=created_at — should fail with INVALID_CURSOR.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
|
||||
|
||||
# Take a cursor that points past the first item.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Pass both 'after' and a large offset. Cursor must win; offset is ignored.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
"offset": "999",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
body = r2.json()
|
||||
# Should land on the second name in sorted order — not skip ahead by 999.
|
||||
assert [a["name"] for a in body["assets"]] == [names[1]]
|
||||
|
||||
|
||||
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-exhaust",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "50",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""First-page request (no `after`) must still return `next_cursor` when
|
||||
more rows exist, or pagination is unreachable from a cold start.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is True
|
||||
assert body.get("next_cursor"), "first page must mint a cursor when more rows exist"
|
||||
|
||||
|
||||
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""When `total` is an exact multiple of `limit`, the final page must
|
||||
NOT carry a next_cursor — there is nothing past it.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
|
||||
# Page 1
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
# Page 2 — should exhaust the set with no cursor for a phantom page 3
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body = r2.json()
|
||||
assert len(body["assets"]) == 2
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
|
||||
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""Cursor pagination must work for every sort field the contract claims.
|
||||
|
||||
Without this, the `created_at` / `updated_at` (time-encoded micros) and
|
||||
`size` (int-encoded) cursor paths go entirely unexercised end-to-end.
|
||||
"""
|
||||
# Sizes increase strictly by index, so `size desc` has a deterministic
|
||||
# expected order. Time-based sorts (created_at / updated_at) can tie when
|
||||
# rows are inserted faster than the DB's timestamp resolution; for those
|
||||
# we check coverage and no-duplicates and let the keyset tiebreaker do
|
||||
# the rest, instead of sleeping between inserts and asserting an order
|
||||
# that depends on clock granularity.
|
||||
names = []
|
||||
for i in range(4):
|
||||
n = f"cursor_{sort_field}_{i:02d}.safetensors"
|
||||
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
|
||||
names.append(n)
|
||||
|
||||
params = {
|
||||
"include_tags": f"unit-tests,cursor-{sort_field}",
|
||||
"sort": sort_field,
|
||||
"order": "desc",
|
||||
"limit": "2",
|
||||
}
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
after = body.get("next_cursor")
|
||||
pages += 1
|
||||
if after is None:
|
||||
break
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
# No duplicates: a faulty keyset boundary that returns the same row across
|
||||
# two pages must fail this check.
|
||||
assert len(seen) == len(set(seen)), (
|
||||
f"cursor walk repeated rows for sort={sort_field}: {seen}"
|
||||
)
|
||||
# Full coverage: every seeded asset reached exactly once.
|
||||
assert set(seen) == set(names), (
|
||||
f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}"
|
||||
)
|
||||
# Strict order check for the only field with a clock-independent ordering.
|
||||
if sort_field == "size":
|
||||
assert seen == list(reversed(names)), (
|
||||
f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}"
|
||||
)
|
||||
|
||||
|
||||
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""A cursor minted under desc order replayed against asc must 400, not
|
||||
silently walk the wrong direction."""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "desc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay with order flipped to asc — server must reject the cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
|
||||
"""A cursor carrying an out-of-range microsecond timestamp must map to
|
||||
400 INVALID_CURSOR, not 500."""
|
||||
import base64
|
||||
import json
|
||||
# 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR.
|
||||
# `o` and `order=` must be set; otherwise decode fails earlier on the
|
||||
# missing-order branch and the µs-overflow path is never exercised.
|
||||
payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"}
|
||||
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||
cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert r.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
|
||||
|
||||
# Page 1.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
page1_names = [a["name"] for a in body["assets"]]
|
||||
cursor = body["next_cursor"]
|
||||
assert cursor is not None
|
||||
assert page1_names == names[:2]
|
||||
|
||||
# Delete an item from page 1 (already returned) — cursor should still
|
||||
# locate the next page from where it was minted, not re-index.
|
||||
target_id = body["assets"][0]["id"]
|
||||
d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120)
|
||||
assert d.status_code in (200, 204), d.text
|
||||
|
||||
# Page 2 via cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body2 = r2.json()
|
||||
assert [a["name"] for a in body2["assets"]] == names[2:]
|
||||
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""POST /prompt enforces canonical-UUID job ids at creation time.
|
||||
|
||||
Lives in assets_test because it uses this suite's booted-server fixture. The
|
||||
invariant itself is pipeline-wide: a job id is stored and compared verbatim
|
||||
downstream — history keys, websocket correlation, and /interrupt matching —
|
||||
so a job minted with a non-canonical id would miss every exact-match lookup.
|
||||
|
||||
The prompt bodies here are intentionally invalid workflows — prompt_id
|
||||
validation happens before workflow validation, so a rejected id returns
|
||||
``invalid_prompt_id`` while an accepted id falls through to the ordinary
|
||||
workflow-validation error (proving it cleared the id check).
|
||||
"""
|
||||
import requests
|
||||
|
||||
|
||||
def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response:
|
||||
return http.post(api_base + "/prompt", json=body, timeout=30)
|
||||
|
||||
|
||||
def _error_type(r: requests.Response) -> str:
|
||||
return r.json()["error"]["type"]
|
||||
|
||||
|
||||
def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str):
|
||||
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": "not-a-uuid"})
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str):
|
||||
# Previously str()-coerced (123 became the job id "123"); must now be a 400,
|
||||
# not a 500 from uuid.UUID choking on a non-string.
|
||||
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": 123})
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str):
|
||||
# Parseable as a UUID, but not the canonical lowercase form: rejected
|
||||
# loudly rather than silently rewritten (downstream lookups match the
|
||||
# stored id exactly).
|
||||
r = _post_prompt(
|
||||
http,
|
||||
api_base,
|
||||
{"prompt": {}, "prompt_id": "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"},
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_canonical_uuid_accepted(http: requests.Session, api_base: str):
|
||||
# The id clears validation; the empty workflow then fails ordinary prompt
|
||||
# validation, proving the request got past the id check.
|
||||
r = _post_prompt(
|
||||
http,
|
||||
api_base,
|
||||
{"prompt": {}, "prompt_id": "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"},
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) != "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str):
|
||||
# Explicit null means "server generates" and must not be rejected as an
|
||||
# invalid id. (The minted id itself is not observable here because the
|
||||
# workflow is invalid; unit tests cover validate_job_id directly.)
|
||||
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": None})
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) != "invalid_prompt_id"
|
||||
@ -95,7 +95,7 @@ def _make_asset(
|
||||
def _ensure_missing_tag(session: Session):
|
||||
"""Ensure the 'missing' tag exists."""
|
||||
if not session.get(Tag, "missing"):
|
||||
session.add(Tag(name="missing", tag_type="system"))
|
||||
session.add(Tag(name="missing"))
|
||||
session.flush()
|
||||
|
||||
|
||||
|
||||
@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert custom_tag in used_names
|
||||
|
||||
# Hard-delete the asset so the tag usage drops to zero
|
||||
rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120)
|
||||
# Delete the asset reference so the tag usage drops to zero
|
||||
rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
# Now the custom tag must not be returned when include_zero=false
|
||||
|
||||
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
import torch
|
||||
import av
|
||||
import numpy as np
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gradient_components():
|
||||
"""Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth"""
|
||||
width, height, frames = 64, 64, 3
|
||||
ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3)
|
||||
return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def src8(gradient_components, tmp_path_factory):
|
||||
"""8-bit h264 mp4 (Create Video default)"""
|
||||
path = str(tmp_path_factory.mktemp("video") / "src8.mp4")
|
||||
VideoFromComponents(gradient_components).save_to(path)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def src10(gradient_components, tmp_path_factory):
|
||||
"""10-bit h264 mp4 (Create Video with bit_depth=10)"""
|
||||
path = str(tmp_path_factory.mktemp("video") / "src10.mp4")
|
||||
VideoFromComponents(gradient_components, bit_depth=10).save_to(path)
|
||||
return path
|
||||
|
||||
|
||||
def probe(path):
|
||||
"""(codec, pix_fmt, bit_depth) of the first video stream"""
|
||||
with av.open(path) as container:
|
||||
stream = container.streams.video[0]
|
||||
return (stream.codec.name, stream.format.name, max(c.bits for c in stream.format.components))
|
||||
|
||||
|
||||
def decoded_levels(path):
|
||||
"""Unique tonal levels in the first decoded frame (banding measure)"""
|
||||
with av.open(path) as container:
|
||||
frame = next(container.decode(container.streams.video[0]))
|
||||
return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0]))
|
||||
|
||||
|
||||
def video_packet_bytes(path):
|
||||
"""Raw video packet payloads; identical to the source's only for a true remux"""
|
||||
with av.open(path) as container:
|
||||
return [bytes(p) for p in container.demux(container.streams.video[0]) if p.size]
|
||||
|
||||
|
||||
def test_create_video_bit_depth(src8, src10):
|
||||
"""Create Video's bit_depth picks the encoded depth (default 8-bit); 10-bit reduces banding"""
|
||||
assert probe(src8) == ("h264", "yuv420p", 8)
|
||||
assert probe(src10) == ("h264", "yuv420p10le", 10)
|
||||
assert decoded_levels(src10) > 2 * decoded_levels(src8)
|
||||
|
||||
|
||||
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
|
||||
"""Save Video (no bit_depth = auto) stream-copies the source, preserving its depth byte-for-byte"""
|
||||
for name, src in [("p8", src8), ("p10", src10)]:
|
||||
path = str(tmp_path / f"{name}.mp4")
|
||||
VideoFromFile(src).save_to(path)
|
||||
assert probe(path) == probe(src)
|
||||
assert video_packet_bytes(path) == video_packet_bytes(src)
|
||||
|
||||
|
||||
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
|
||||
"""An explicit bit_depth different from the source forces a re-encode to that depth"""
|
||||
down = str(tmp_path / "down8.mp4")
|
||||
VideoFromFile(src10).save_to(down, bit_depth=8)
|
||||
assert probe(down) == ("h264", "yuv420p", 8)
|
||||
|
||||
up = str(tmp_path / "up10.mp4")
|
||||
VideoFromFile(src8).save_to(up, bit_depth=10)
|
||||
assert probe(up) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_trim_keeps_source_depth(src10, tmp_path):
|
||||
"""Video Slice re-encodes (trim) but preserves the source's 10-bit depth"""
|
||||
path = str(tmp_path / "trim.mp4")
|
||||
VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path)
|
||||
assert probe(path) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_get_bit_depth(gradient_components, src8, src10):
|
||||
"""get_bit_depth reports a video's depth (backs the Get Video Components output)"""
|
||||
assert VideoFromFile(src8).get_bit_depth() == 8
|
||||
assert VideoFromFile(src10).get_bit_depth() == 10
|
||||
assert VideoFromComponents(gradient_components, bit_depth=10).get_bit_depth() == 10
|
||||
assert VideoFromComponents(gradient_components).get_bit_depth() == 8
|
||||
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user