mirror of
https://github.com/comfyanonymous/ComfyUI.git
synced 2026-06-23 16:29:25 +08:00
Merge branch 'master' into fix/ideogram4-llama-template
This commit is contained in:
commit
1c93e5f433
18
README.md
18
README.md
@ -364,7 +364,7 @@ For models compatible with Iluvatar Extension for PyTorch. Here's a step-by-step
|
||||
| Flag | Description |
|
||||
|------|-------------|
|
||||
| `--enable-manager` | Enable ComfyUI-Manager |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (requires `--enable-manager`) |
|
||||
| `--enable-manager-legacy-ui` | Use the legacy manager UI instead of the new UI (implies `--enable-manager`) |
|
||||
| `--disable-manager-ui` | Disable the manager UI and endpoints while keeping background features like security checks and scheduled installation completion (requires `--enable-manager`) |
|
||||
|
||||
|
||||
@ -382,11 +382,7 @@ For AMD 7600 and maybe other RDNA3 cards: ```HSA_OVERRIDE_GFX_VERSION=11.0.0 pyt
|
||||
|
||||
### AMD ROCm Tips
|
||||
|
||||
You can enable experimental memory efficient attention on recent pytorch in ComfyUI on some AMD GPUs using this command, it should already be enabled by default on RDNA3. If this improves speed for you on latest pytorch on your GPU please report it so that I can enable it by default.
|
||||
|
||||
```TORCH_ROCM_AOTRITON_ENABLE_EXPERIMENTAL=1 python main.py --use-pytorch-cross-attention```
|
||||
|
||||
You can also try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
You can try setting this env variable `PYTORCH_TUNABLEOP_ENABLED=1` which might speed things up at the cost of a very slow initial run.
|
||||
|
||||
# Notes
|
||||
|
||||
@ -462,16 +458,6 @@ To use the most up-to-date frontend version:
|
||||
|
||||
This approach allows you to easily switch between the stable fortnightly release and the cutting-edge daily updates, or even specific versions for testing purposes.
|
||||
|
||||
### Accessing the Legacy Frontend
|
||||
|
||||
If you need to use the legacy frontend for any reason, you can access it using the following command line argument:
|
||||
|
||||
```
|
||||
--front-end-version Comfy-Org/ComfyUI_legacy_frontend@latest
|
||||
```
|
||||
|
||||
This will use a snapshot of the legacy frontend preserved in the [ComfyUI Legacy Frontend repository](https://github.com/Comfy-Org/ComfyUI_legacy_frontend).
|
||||
|
||||
# QA
|
||||
|
||||
### Which GPU should I buy for this?
|
||||
|
||||
39
alembic_db/versions/0004_drop_tag_type.py
Normal file
39
alembic_db/versions/0004_drop_tag_type.py
Normal file
@ -0,0 +1,39 @@
|
||||
"""
|
||||
Drop the vestigial tags.tag_type column.
|
||||
|
||||
tag_type was always "user" in practice — no code path ever set it to anything
|
||||
else (no system/seeded classification was ever wired up) and nothing queried it.
|
||||
The column, its index (ix_tags_tag_type), and the corresponding API field were
|
||||
dead weight, so they are removed.
|
||||
|
||||
Revision ID: 0004_drop_tag_type
|
||||
Revises: 0003_add_metadata_job_id
|
||||
Create Date: 2026-06-03
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
revision = "0004_drop_tag_type"
|
||||
down_revision = "0003_add_metadata_job_id"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
with op.batch_alter_table("tags") as batch_op:
|
||||
batch_op.drop_index("ix_tags_tag_type")
|
||||
batch_op.drop_column("tag_type")
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
with op.batch_alter_table("tags") as batch_op:
|
||||
batch_op.add_column(
|
||||
sa.Column(
|
||||
"tag_type",
|
||||
sa.String(length=32),
|
||||
nullable=False,
|
||||
server_default="user",
|
||||
)
|
||||
)
|
||||
batch_op.create_index("ix_tags_tag_type", ["tag_type"])
|
||||
@ -39,6 +39,7 @@ from app.assets.services import (
|
||||
update_asset_metadata,
|
||||
upload_from_temp_path,
|
||||
)
|
||||
from app.assets.services.cursor import InvalidCursorError
|
||||
from app.assets.services.tagging import list_tag_histogram
|
||||
|
||||
ROUTES = web.RouteTableDef()
|
||||
@ -174,7 +175,7 @@ def _build_asset_response(result: schemas.AssetDetailResult | schemas.UploadResu
|
||||
user_metadata=result.ref.user_metadata or {},
|
||||
metadata=result.ref.system_metadata,
|
||||
job_id=result.ref.job_id,
|
||||
prompt_id=result.ref.job_id, # deprecated: mirrors job_id for cloud compat
|
||||
prompt_id=result.ref.job_id, # deprecated alias of job_id, kept for compatibility
|
||||
created_at=result.ref.created_at,
|
||||
updated_at=result.ref.updated_at,
|
||||
last_access_time=result.ref.last_access_time,
|
||||
@ -211,24 +212,37 @@ async def list_assets_route(request: web.Request) -> web.Response:
|
||||
order_candidate = (q.order or "desc").lower()
|
||||
order = order_candidate if order_candidate in {"asc", "desc"} else "desc"
|
||||
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
)
|
||||
try:
|
||||
result = list_assets_page(
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
include_tags=q.include_tags,
|
||||
exclude_tags=q.exclude_tags,
|
||||
name_contains=q.name_contains,
|
||||
metadata_filter=q.metadata_filter,
|
||||
limit=q.limit,
|
||||
offset=q.offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after=q.after,
|
||||
)
|
||||
except InvalidCursorError as e:
|
||||
return _build_error_response(400, "INVALID_CURSOR", str(e))
|
||||
|
||||
summaries = [_build_asset_response(item) for item in result.items]
|
||||
|
||||
# has_more semantics differ by mode:
|
||||
# - cursor mode: a non-empty next_cursor means there are more results.
|
||||
# - offset mode: derived from total - (offset + page size).
|
||||
if q.after is not None:
|
||||
has_more = result.next_cursor is not None
|
||||
else:
|
||||
has_more = (q.offset + len(summaries)) < result.total
|
||||
|
||||
payload = schemas_out.AssetsList(
|
||||
assets=summaries,
|
||||
total=result.total,
|
||||
has_more=(q.offset + len(summaries)) < result.total,
|
||||
has_more=has_more,
|
||||
next_cursor=result.next_cursor,
|
||||
)
|
||||
return web.json_response(payload.model_dump(mode="json", exclude_none=True))
|
||||
|
||||
@ -519,18 +533,14 @@ async def update_asset_route(request: web.Request) -> web.Response:
|
||||
@_require_assets_feature_enabled
|
||||
async def delete_asset_route(request: web.Request) -> web.Response:
|
||||
reference_id = str(uuid.UUID(request.match_info["id"]))
|
||||
delete_content_param = request.query.get("delete_content")
|
||||
delete_content = (
|
||||
False
|
||||
if delete_content_param is None
|
||||
else delete_content_param.lower() not in {"0", "false", "no"}
|
||||
)
|
||||
|
||||
try:
|
||||
# Deleting an asset is a soft delete of the reference; the underlying
|
||||
# content is preserved (it may be shared with other references).
|
||||
deleted = delete_asset_reference(
|
||||
reference_id=reference_id,
|
||||
owner_id=USER_MANAGER.get_request_user_id(request),
|
||||
delete_content_if_orphan=delete_content,
|
||||
delete_content_if_orphan=False,
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(
|
||||
@ -575,8 +585,8 @@ async def get_tags(request: web.Request) -> web.Response:
|
||||
)
|
||||
|
||||
tags = [
|
||||
schemas_out.TagUsage(name=name, count=count, type=tag_type)
|
||||
for (name, tag_type, count) in rows
|
||||
schemas_out.TagUsage(name=name, count=count)
|
||||
for (name, count) in rows
|
||||
]
|
||||
payload = schemas_out.TagsList(
|
||||
tags=tags, total=total, has_more=(query.offset + len(tags)) < total
|
||||
|
||||
@ -59,6 +59,11 @@ class ListAssetsQuery(BaseModel):
|
||||
|
||||
limit: conint(ge=1, le=500) = 20
|
||||
offset: conint(ge=0) = 0
|
||||
# Opaque keyset cursor. When supplied, `offset` is ignored. Cursor pagination
|
||||
# is supported for sort values `created_at`, `updated_at`, `name`, `size`.
|
||||
# Supplying `after` together with `sort=last_access_time` returns
|
||||
# 400 INVALID_CURSOR; that sort only supports offset/limit.
|
||||
after: str | None = None
|
||||
|
||||
sort: Literal["name", "created_at", "updated_at", "size", "last_access_time"] = (
|
||||
"created_at"
|
||||
|
||||
@ -41,12 +41,13 @@ class AssetsList(BaseModel):
|
||||
assets: list[Asset]
|
||||
total: int
|
||||
has_more: bool
|
||||
# Opaque cursor for the next page. Omitted when there are no more results.
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
class TagUsage(BaseModel):
|
||||
name: str
|
||||
count: int
|
||||
type: str
|
||||
|
||||
|
||||
class TagsList(BaseModel):
|
||||
|
||||
@ -227,7 +227,6 @@ class Tag(Base):
|
||||
__tablename__ = "tags"
|
||||
|
||||
name: Mapped[str] = mapped_column(String(512), primary_key=True)
|
||||
tag_type: Mapped[str] = mapped_column(String(32), nullable=False, default="user")
|
||||
|
||||
asset_reference_links: Mapped[list[AssetReferenceTag]] = relationship(
|
||||
back_populates="tag",
|
||||
@ -240,7 +239,5 @@ class Tag(Base):
|
||||
overlaps="asset_reference_links,tag_links,tags,asset_reference",
|
||||
)
|
||||
|
||||
__table_args__ = (Index("ix_tags_tag_type", "tag_type"),)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Tag {self.name}>"
|
||||
|
||||
@ -266,9 +266,18 @@ def list_references_page(
|
||||
metadata_filter: dict | None = None,
|
||||
sort: str | None = None,
|
||||
order: str | None = None,
|
||||
after_cursor_value: object | None = None,
|
||||
after_cursor_id: str | None = None,
|
||||
) -> tuple[list[AssetReference], dict[str, list[str]], int]:
|
||||
"""List references with pagination, filtering, and sorting.
|
||||
|
||||
When ``after_cursor_value``/``after_cursor_id`` are supplied the query uses
|
||||
keyset pagination — ``offset`` is ignored and a WHERE clause selects rows
|
||||
strictly after the given ``(sort_col, id)`` position in the active sort
|
||||
direction. The cursor value must already be typed for the column
|
||||
(datetime for time sorts, int for size, str for name); the caller decodes
|
||||
the opaque cursor string and resolves to the typed value.
|
||||
|
||||
Returns (references, tag_map, total_count).
|
||||
"""
|
||||
base = (
|
||||
@ -297,9 +306,31 @@ def list_references_page(
|
||||
"size": Asset.size_bytes,
|
||||
}
|
||||
sort_col = sort_map.get(sort, AssetReference.created_at)
|
||||
sort_exp = sort_col.desc() if order == "desc" else sort_col.asc()
|
||||
descending = order == "desc"
|
||||
|
||||
base = base.order_by(sort_exp).limit(limit).offset(offset)
|
||||
# Keyset WHERE: (sort_col, id) strictly less-than / greater-than the cursor.
|
||||
# Equivalent to: sort_col <op> v OR (sort_col = v AND id <op> cursor_id).
|
||||
if after_cursor_value is not None and after_cursor_id is not None:
|
||||
if descending:
|
||||
keyset = sa.or_(
|
||||
sort_col < after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id < after_cursor_id),
|
||||
)
|
||||
else:
|
||||
keyset = sa.or_(
|
||||
sort_col > after_cursor_value,
|
||||
sa.and_(sort_col == after_cursor_value, AssetReference.id > after_cursor_id),
|
||||
)
|
||||
base = base.where(keyset)
|
||||
|
||||
# Secondary ORDER BY id (matching the primary direction) gives the keyset
|
||||
# comparison a deterministic tiebreaker on duplicate sort_col values.
|
||||
id_exp = AssetReference.id.desc() if descending else AssetReference.id.asc()
|
||||
sort_exp = sort_col.desc() if descending else sort_col.asc()
|
||||
|
||||
base = base.order_by(sort_exp, id_exp).limit(limit)
|
||||
if after_cursor_id is None:
|
||||
base = base.offset(offset)
|
||||
|
||||
count_stmt = (
|
||||
select(sa.func.count())
|
||||
|
||||
@ -55,13 +55,11 @@ def validate_tags_exist(session: Session, tags: list[str]) -> None:
|
||||
raise ValueError(f"Unknown tags: {missing}")
|
||||
|
||||
|
||||
def ensure_tags_exist(
|
||||
session: Session, names: Iterable[str], tag_type: str = "user"
|
||||
) -> None:
|
||||
def ensure_tags_exist(session: Session, names: Iterable[str]) -> None:
|
||||
wanted = normalize_tags(list(names))
|
||||
if not wanted:
|
||||
return
|
||||
rows = [{"name": n, "tag_type": tag_type} for n in list(dict.fromkeys(wanted))]
|
||||
rows = [{"name": n} for n in list(dict.fromkeys(wanted))]
|
||||
ins = (
|
||||
sqlite.insert(Tag)
|
||||
.values(rows)
|
||||
@ -97,7 +95,7 @@ def set_reference_tags(
|
||||
to_remove = [t for t in current if t not in desired]
|
||||
|
||||
if to_add:
|
||||
ensure_tags_exist(session, to_add, tag_type="user")
|
||||
ensure_tags_exist(session, to_add)
|
||||
session.add_all(
|
||||
[
|
||||
AssetReferenceTag(
|
||||
@ -142,7 +140,7 @@ def add_tags_to_reference(
|
||||
return AddTagsResult(added=[], already_present=[], total_tags=total)
|
||||
|
||||
if create_if_missing:
|
||||
ensure_tags_exist(session, norm, tag_type="user")
|
||||
ensure_tags_exist(session, norm)
|
||||
|
||||
current = set(get_reference_tags(session, reference_id))
|
||||
|
||||
@ -289,7 +287,6 @@ def list_tags_with_usage(
|
||||
q = (
|
||||
select(
|
||||
Tag.name,
|
||||
Tag.tag_type,
|
||||
func.coalesce(counts_sq.c.cnt, 0).label("count"),
|
||||
)
|
||||
.select_from(Tag)
|
||||
@ -331,7 +328,7 @@ def list_tags_with_usage(
|
||||
rows = (session.execute(q.limit(limit).offset(offset))).all()
|
||||
total = (session.execute(total_q)).scalar_one()
|
||||
|
||||
rows_norm = [(name, ttype, int(count or 0)) for (name, ttype, count) in rows]
|
||||
rows_norm = [(name, int(count or 0)) for (name, count) in rows]
|
||||
return rows_norm, int(total or 0)
|
||||
|
||||
|
||||
|
||||
@ -355,7 +355,7 @@ def insert_asset_specs(specs: list[SeedAssetSpec], tag_pool: set[str]) -> int:
|
||||
return 0
|
||||
with create_session() as sess:
|
||||
if tag_pool:
|
||||
ensure_tags_exist(sess, tag_pool, tag_type="user")
|
||||
ensure_tags_exist(sess, tag_pool)
|
||||
result = batch_insert_seed_assets(sess, specs=specs, owner_id="")
|
||||
sess.commit()
|
||||
return result.inserted_refs
|
||||
|
||||
@ -1,8 +1,19 @@
|
||||
import contextlib
|
||||
import mimetypes
|
||||
import os
|
||||
from datetime import timezone
|
||||
from typing import Sequence
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
from app.assets.database.models import Asset
|
||||
from app.assets.database.queries import (
|
||||
@ -149,6 +160,16 @@ def delete_asset_reference(
|
||||
owner_id: str,
|
||||
delete_content_if_orphan: bool = True,
|
||||
) -> bool:
|
||||
"""Delete an asset reference.
|
||||
|
||||
With ``delete_content_if_orphan=False`` (a soft delete), the reference is
|
||||
hidden and the underlying content is preserved. With ``True``, the content
|
||||
is also removed once it becomes orphaned.
|
||||
|
||||
Note: the public DELETE /api/assets/{id} endpoint always soft-deletes
|
||||
(passes ``False``); the orphan-reclamation path is intentionally
|
||||
internal-only, retained for a future GC/admin caller.
|
||||
"""
|
||||
with create_session() as session:
|
||||
if not delete_content_if_orphan:
|
||||
# Soft delete: mark the reference as deleted but keep everything
|
||||
@ -242,6 +263,11 @@ def get_asset_by_hash(asset_hash: str) -> AssetData | None:
|
||||
return extract_asset_data(asset)
|
||||
|
||||
|
||||
# Sort fields that support cursor pagination. `last_access_time` is not
|
||||
# in this list — it falls back to offset/limit.
|
||||
_CURSOR_SORT_FIELDS = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
def list_assets_page(
|
||||
owner_id: str = "",
|
||||
include_tags: Sequence[str] | None = None,
|
||||
@ -252,7 +278,39 @@ def list_assets_page(
|
||||
offset: int = 0,
|
||||
sort: str = "created_at",
|
||||
order: str = "desc",
|
||||
after: str | None = None,
|
||||
) -> ListAssetsResult:
|
||||
"""List assets with optional cursor pagination.
|
||||
|
||||
When ``after`` is supplied it overrides ``offset``. The cursor's sort field
|
||||
must match ``sort`` and be in the cursor-supported allowlist; mismatches
|
||||
raise InvalidCursorError so the handler can map to 400 INVALID_CURSOR.
|
||||
"""
|
||||
cursor_value: object | None = None
|
||||
cursor_id: str | None = None
|
||||
# Mint next_cursor on every page where the sort is cursor-supported, not
|
||||
# only when the request itself arrived with a cursor. Otherwise a first
|
||||
# request (no `after`) returns next_cursor=None and the client can never
|
||||
# enter cursor mode.
|
||||
mint_cursor = sort in _CURSOR_SORT_FIELDS
|
||||
|
||||
if after is not None:
|
||||
if sort not in _CURSOR_SORT_FIELDS:
|
||||
raise InvalidCursorError(
|
||||
f"cursor pagination is not supported for sort={sort!r}"
|
||||
)
|
||||
payload = decode_cursor(after, _CURSOR_SORT_FIELDS, expected_order=order)
|
||||
if payload.sort_field != sort:
|
||||
raise InvalidCursorError(
|
||||
f"cursor sort field {payload.sort_field!r} does not match request sort {sort!r}"
|
||||
)
|
||||
cursor_value, cursor_id = _resolve_cursor_value(payload), payload.id
|
||||
|
||||
# Over-fetch by one row so we can distinguish "exactly `limit` rows total
|
||||
# remaining" from "more rows past this page" without a second query. Drop
|
||||
# the sentinel before returning.
|
||||
fetch_limit = limit + 1 if mint_cursor else limit
|
||||
|
||||
with create_session() as session:
|
||||
refs, tag_map, total = list_references_page(
|
||||
session,
|
||||
@ -261,12 +319,22 @@ def list_assets_page(
|
||||
exclude_tags=exclude_tags,
|
||||
name_contains=name_contains,
|
||||
metadata_filter=metadata_filter,
|
||||
limit=limit,
|
||||
limit=fetch_limit,
|
||||
offset=offset,
|
||||
sort=sort,
|
||||
order=order,
|
||||
after_cursor_value=cursor_value,
|
||||
after_cursor_id=cursor_id,
|
||||
)
|
||||
|
||||
next_cursor: str | None = None
|
||||
if mint_cursor and len(refs) > limit:
|
||||
# There's at least one more row past this page — mint a cursor from
|
||||
# the last row of the page (i.e. index `limit - 1`, since we
|
||||
# over-fetched), and drop the sentinel.
|
||||
next_cursor = _encode_next_cursor(refs[limit - 1], sort, order)
|
||||
refs = refs[:limit]
|
||||
|
||||
items: list[AssetSummaryData] = []
|
||||
for ref in refs:
|
||||
items.append(
|
||||
@ -277,7 +345,39 @@ def list_assets_page(
|
||||
)
|
||||
)
|
||||
|
||||
return ListAssetsResult(items=items, total=total)
|
||||
return ListAssetsResult(items=items, total=total, next_cursor=next_cursor)
|
||||
|
||||
|
||||
def _resolve_cursor_value(payload: CursorPayload) -> object:
|
||||
"""Map a decoded cursor payload to a column-typed Python value."""
|
||||
if payload.sort_field in ("created_at", "updated_at"):
|
||||
# DB stores naive UTC; strip tzinfo so the comparison binds against a
|
||||
# `TIMESTAMP WITHOUT TIME ZONE` column without an offset shift.
|
||||
return decode_cursor_time(payload).replace(tzinfo=None)
|
||||
if payload.sort_field == "size":
|
||||
return decode_cursor_int(payload)
|
||||
return payload.value # name, str-typed
|
||||
|
||||
|
||||
def _encode_next_cursor(ref, sort: str, order: str) -> str | None:
|
||||
"""Mint a cursor pointing at *ref* for the given sort dimension.
|
||||
|
||||
Returns None when the boundary row carries a NULL sort value (e.g. an asset
|
||||
record whose size_bytes hasn't been backfilled). Continuing pagination
|
||||
across a NULL boundary is undefined under keyset ordering — better to
|
||||
truncate cleanly here than to mint a cursor that mis-positions.
|
||||
"""
|
||||
if sort == "name":
|
||||
return encode_cursor("name", ref.name, ref.id, order=order)
|
||||
if sort == "size":
|
||||
if ref.asset is None or ref.asset.size_bytes is None:
|
||||
return None
|
||||
return encode_cursor("size", str(ref.asset.size_bytes), ref.id, order=order)
|
||||
# created_at / updated_at — DB datetimes are naive UTC; attach tz before encoding.
|
||||
value = ref.created_at if sort == "created_at" else ref.updated_at
|
||||
if value is None:
|
||||
return None
|
||||
return encode_cursor_from_time(sort, value.replace(tzinfo=timezone.utc), ref.id, order=order)
|
||||
|
||||
|
||||
def resolve_hash_to_path(
|
||||
|
||||
213
app/assets/services/cursor.py
Normal file
213
app/assets/services/cursor.py
Normal file
@ -0,0 +1,213 @@
|
||||
"""Opaque keyset-pagination cursor for /api/assets.
|
||||
|
||||
Payload JSON uses short keys to keep the encoded length small:
|
||||
|
||||
{"s": <sort_field>, "v": <value>, "id": <id>, "o": <order>}
|
||||
|
||||
The `o` key binds the cursor to the sort direction it was minted under,
|
||||
so replaying a `desc` cursor against an `asc` request fails with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
`o` is mandatory on every payload — a cursor without it is rejected as
|
||||
malformed.
|
||||
|
||||
Encoding is base64url with no padding. Cursors are opaque tokens: the
|
||||
payload format is internal to this server, and clients must treat a
|
||||
cursor as a black box handed back via `next_cursor`. No byte-level
|
||||
compatibility with any other implementation is required.
|
||||
|
||||
Time values are serialized as Unix microseconds (UTC) — microsecond
|
||||
precision is sufficient to round-trip the timestamps stored by the
|
||||
database without rounding rows in the same millisecond bucket.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Iterable, Optional
|
||||
|
||||
|
||||
class InvalidCursorError(ValueError):
|
||||
"""Raised on a malformed, oversized, or unsupported-sort-field cursor.
|
||||
|
||||
Map to a 400 response with code ``INVALID_CURSOR`` at the handler.
|
||||
"""
|
||||
|
||||
|
||||
# Wire-format length caps. Cursors are user-controlled, so caps protect the
|
||||
# decode path from oversized allocations and downstream SQL predicates from
|
||||
# unbounded strings.
|
||||
#
|
||||
# MAX_CURSOR_VALUE_LENGTH is 512 to fit the `AssetReference.name` column max
|
||||
# (`String(512)`) — otherwise a long-named asset would mint a cursor the same
|
||||
# server then refuses on the next request.
|
||||
#
|
||||
# MAX_ENCODED_CURSOR_LENGTH is the decode-path guard, sized comfortably above
|
||||
# the largest cursor the per-field caps can produce. Worst case is value + id
|
||||
# at their caps with every character JSON-escaping to the six-byte `\uXXXX`
|
||||
# form (control characters), which is ~5.2 KB once base64url-encoded. At 8192
|
||||
# the encoder can never mint a cursor that exceeds it, so a freshly minted
|
||||
# cursor always decodes on the next request and there is no user-visible
|
||||
# "cursor too long" failure.
|
||||
MAX_ENCODED_CURSOR_LENGTH = 8192
|
||||
MAX_CURSOR_VALUE_LENGTH = 512
|
||||
MAX_CURSOR_ID_LENGTH = 128
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class CursorPayload:
|
||||
sort_field: str
|
||||
value: str
|
||||
id: str
|
||||
order: str
|
||||
|
||||
|
||||
_VALID_ORDERS = ("asc", "desc")
|
||||
|
||||
|
||||
def encode_cursor(sort_field: str, value: str, id: str, order: str = "desc") -> str:
|
||||
"""Encode a cursor payload as a base64url (no-padding) string.
|
||||
|
||||
`order` binds the cursor to the sort direction it was minted under so a
|
||||
later request with a flipped `order` query parameter is rejected with
|
||||
``INVALID_CURSOR`` rather than silently walking the wrong direction.
|
||||
"""
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"order must be one of {_VALID_ORDERS}, got {order!r}")
|
||||
# Symmetric input validation: the encoder must reject anything the
|
||||
# decoder rejects, or the same server will mint cursors it then 400s on
|
||||
# the next request.
|
||||
if not id:
|
||||
raise InvalidCursorError("id must be non-empty")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
payload = {"s": sort_field, "v": value, "id": id, "o": order}
|
||||
raw = json.dumps(payload, separators=(",", ":"), ensure_ascii=False)
|
||||
# No mint-time length guard is needed: the per-field caps above bound the
|
||||
# encoded length well below MAX_ENCODED_CURSOR_LENGTH (see its definition),
|
||||
# so the encoder can never produce a cursor the decode path would reject.
|
||||
return base64.urlsafe_b64encode(raw.encode("utf-8")).rstrip(b"=").decode("ascii")
|
||||
|
||||
|
||||
def encode_cursor_from_time(sort_field: str, t: datetime, id: str, order: str = "desc") -> str:
|
||||
"""Encode a time-typed cursor at Unix microsecond precision.
|
||||
|
||||
Accepts an aware datetime (any timezone) and normalizes to UTC. Naive
|
||||
datetimes are rejected so callers can't accidentally encode the local
|
||||
wall-clock value of a UTC-stored timestamp.
|
||||
"""
|
||||
if t.tzinfo is None:
|
||||
raise ValueError("encode_cursor_from_time requires an aware datetime")
|
||||
micros = _datetime_to_unix_micros(t.astimezone(timezone.utc))
|
||||
return encode_cursor(sort_field, str(micros), id, order=order)
|
||||
|
||||
|
||||
def decode_cursor(
|
||||
cursor: str,
|
||||
allowed_sort_fields: Iterable[str],
|
||||
expected_order: str | None = None,
|
||||
) -> CursorPayload:
|
||||
"""Parse an opaque cursor.
|
||||
|
||||
``allowed_sort_fields`` is the endpoint's accepted sort-field list — a
|
||||
cursor carrying a field outside this set is rejected so a cursor minted
|
||||
for one column can't be replayed against another (e.g. a ``created_at``
|
||||
timestamp string compared against a ``name`` column).
|
||||
|
||||
``expected_order`` (``"asc"``/``"desc"``), when supplied, must match the
|
||||
payload's ``o`` field. ``o`` is required on every payload; a cursor
|
||||
missing it is rejected as malformed.
|
||||
|
||||
Passing no allowed fields rejects every cursor.
|
||||
"""
|
||||
if len(cursor) > MAX_ENCODED_CURSOR_LENGTH:
|
||||
raise InvalidCursorError("cursor exceeds maximum length")
|
||||
|
||||
try:
|
||||
# urlsafe_b64decode requires correct padding; we strip on encode, so
|
||||
# restore the trailing '=' pad here.
|
||||
padding = "=" * (-len(cursor) % 4)
|
||||
raw = base64.urlsafe_b64decode(cursor + padding)
|
||||
except (ValueError, base64.binascii.Error) as e:
|
||||
raise InvalidCursorError(f"encoding: {e}") from e
|
||||
|
||||
try:
|
||||
decoded = json.loads(raw)
|
||||
except (json.JSONDecodeError, UnicodeDecodeError) as e:
|
||||
raise InvalidCursorError(f"payload: {e}") from e
|
||||
|
||||
if not isinstance(decoded, dict):
|
||||
raise InvalidCursorError("payload: expected object")
|
||||
|
||||
sort_field = decoded.get("s")
|
||||
value = decoded.get("v")
|
||||
id = decoded.get("id")
|
||||
order = decoded.get("o")
|
||||
|
||||
if not isinstance(sort_field, str) or not isinstance(value, str) or not isinstance(id, str):
|
||||
raise InvalidCursorError("payload: missing or non-string s/v/id")
|
||||
|
||||
if id == "":
|
||||
raise InvalidCursorError("missing id")
|
||||
if len(id) > MAX_CURSOR_ID_LENGTH:
|
||||
raise InvalidCursorError("id exceeds maximum length")
|
||||
if len(value) > MAX_CURSOR_VALUE_LENGTH:
|
||||
raise InvalidCursorError("value exceeds maximum length")
|
||||
|
||||
if sort_field not in allowed_sort_fields:
|
||||
raise InvalidCursorError(f"unsupported sort field {sort_field!r}")
|
||||
|
||||
if not isinstance(order, str):
|
||||
raise InvalidCursorError("missing or non-string o")
|
||||
if order not in _VALID_ORDERS:
|
||||
raise InvalidCursorError(f"unsupported order {order!r}")
|
||||
if expected_order is not None and order != expected_order:
|
||||
raise InvalidCursorError(
|
||||
f"cursor order {order!r} does not match request order {expected_order!r}"
|
||||
)
|
||||
|
||||
return CursorPayload(sort_field=sort_field, value=value, id=id, order=order)
|
||||
|
||||
|
||||
def decode_cursor_time(payload: Optional[CursorPayload]) -> datetime:
|
||||
"""Parse a time-typed cursor value as Unix microseconds, returning UTC."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
micros = int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid timestamp: {e}") from e
|
||||
try:
|
||||
return _unix_micros_to_datetime(micros)
|
||||
except (OverflowError, OSError, ValueError) as e:
|
||||
# Crafted out-of-range microseconds (e.g. > datetime.MAX_YEAR) blow up
|
||||
# in fromtimestamp / datetime construction. Map to 400, not 500.
|
||||
raise InvalidCursorError(f"value is out of representable range: {e}") from e
|
||||
|
||||
|
||||
def decode_cursor_int(payload: Optional[CursorPayload]) -> int:
|
||||
"""Parse a cursor value as a base-10 integer."""
|
||||
if payload is None:
|
||||
raise InvalidCursorError("nil cursor payload")
|
||||
try:
|
||||
return int(payload.value)
|
||||
except ValueError as e:
|
||||
raise InvalidCursorError(f"value is not a valid integer: {e}") from e
|
||||
|
||||
|
||||
_EPOCH = datetime(1970, 1, 1, tzinfo=timezone.utc)
|
||||
|
||||
|
||||
def _datetime_to_unix_micros(t: datetime) -> int:
|
||||
"""Convert an aware UTC datetime to Unix microseconds (integer math)."""
|
||||
delta = t - _EPOCH
|
||||
return (delta.days * 86_400 + delta.seconds) * 1_000_000 + delta.microseconds
|
||||
|
||||
|
||||
def _unix_micros_to_datetime(micros: int) -> datetime:
|
||||
"""Convert Unix microseconds to a UTC datetime, preserving precision."""
|
||||
seconds, micro_remainder = divmod(micros, 1_000_000)
|
||||
return datetime.fromtimestamp(seconds, tz=timezone.utc).replace(microsecond=micro_remainder)
|
||||
@ -56,7 +56,6 @@ class IngestResult:
|
||||
|
||||
class TagUsage(NamedTuple):
|
||||
name: str
|
||||
tag_type: str
|
||||
count: int
|
||||
|
||||
|
||||
@ -71,6 +70,7 @@ class AssetSummaryData:
|
||||
class ListAssetsResult:
|
||||
items: list[AssetSummaryData]
|
||||
total: int
|
||||
next_cursor: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
|
||||
@ -75,7 +75,7 @@ def list_tags(
|
||||
owner_id=owner_id,
|
||||
)
|
||||
|
||||
return [TagUsage(name, tag_type, count) for name, tag_type, count in rows], total
|
||||
return [TagUsage(name, count) for name, count in rows], total
|
||||
|
||||
|
||||
def list_tag_histogram(
|
||||
|
||||
@ -115,6 +115,7 @@ cache_group.add_argument("--cache-ram", nargs='*', type=float, default=[], metav
|
||||
cache_group.add_argument("--cache-classic", action="store_true", help="Use the old style (aggressive) caching.")
|
||||
cache_group.add_argument("--cache-lru", type=int, default=0, help="Use LRU caching with a maximum of N node results cached. May use more RAM/VRAM.")
|
||||
cache_group.add_argument("--cache-none", action="store_true", help="Reduced RAM/VRAM usage at the expense of executing every node for each run.")
|
||||
cache_group.add_argument("--high-ram", action="store_true", help="Can improve performance slightly on high RAM or on systems where pagefile use is preferred over model loading.")
|
||||
|
||||
attn_group = parser.add_mutually_exclusive_group()
|
||||
attn_group.add_argument("--use-split-cross-attention", action="store_true", help="Use the split cross attention optimization. Ignored when xformers is used.")
|
||||
@ -133,7 +134,7 @@ upcast.add_argument("--dont-upcast-attention", action="store_true", help="Disabl
|
||||
parser.add_argument("--enable-manager", action="store_true", help="Enable the ComfyUI-Manager feature.")
|
||||
manager_group = parser.add_mutually_exclusive_group()
|
||||
manager_group.add_argument("--disable-manager-ui", action="store_true", help="Disables only the ComfyUI-Manager UI and endpoints. Scheduled installations and similar background tasks will still operate.")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager")
|
||||
manager_group.add_argument("--enable-manager-legacy-ui", action="store_true", help="Enables the legacy UI of ComfyUI-Manager. Implies --enable-manager.")
|
||||
|
||||
|
||||
vram_group = parser.add_mutually_exclusive_group()
|
||||
@ -166,6 +167,8 @@ class PerformanceFeature(enum.Enum):
|
||||
|
||||
parser.add_argument("--fast", nargs="*", type=PerformanceFeature, help="Enable some untested and potentially quality deteriorating optimizations. This is used to test new features so using it might crash your comfyui. --fast with no arguments enables everything. You can pass a list specific optimizations if you only want to enable specific ones. Current valid optimizations: {}".format(" ".join(map(lambda c: c.value, PerformanceFeature))))
|
||||
|
||||
parser.add_argument("--debug-hang", action="store_true", help="Enable stack trace dumps on Ctrl-C for debugging hangs.")
|
||||
|
||||
parser.add_argument("--disable-pinned-memory", action="store_true", help="Disable pinned memory use.")
|
||||
|
||||
parser.add_argument("--mmap-torch-files", action="store_true", help="Use mmap when loading ckpt/pt files.")
|
||||
@ -247,6 +250,9 @@ else:
|
||||
if args.cache_ram is not None and len(args.cache_ram) > 2:
|
||||
parser.error("--cache-ram accepts at most two values: active GB and inactive GB")
|
||||
|
||||
if args.high_ram:
|
||||
args.cache_classic = True
|
||||
|
||||
if args.windows_standalone_build:
|
||||
args.auto_launch = True
|
||||
|
||||
@ -256,6 +262,10 @@ if args.disable_auto_launch:
|
||||
if args.force_fp16:
|
||||
args.fp16_unet = True
|
||||
|
||||
# '--enable-manager-legacy-ui' is meaningless unless the manager is enabled, so imply '--enable-manager'.
|
||||
if args.enable_manager_legacy_ui:
|
||||
args.enable_manager = True
|
||||
|
||||
|
||||
# '--fast' is not provided, use an empty set
|
||||
if args.fast is None:
|
||||
|
||||
@ -1,7 +1,13 @@
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.text_encoders.bert import BertAttention
|
||||
import comfy.model_management
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from comfy.ldm.depth_anything_3.reference_view_selector import (
|
||||
select_reference_view, reorder_by_reference, restore_original_order,
|
||||
THRESH_FOR_REF_SELECTION,
|
||||
)
|
||||
|
||||
|
||||
class Dino2AttentionOutput(torch.nn.Module):
|
||||
@ -14,13 +20,41 @@ class Dino2AttentionOutput(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2AttentionBlock(torch.nn.Module):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations):
|
||||
def __init__(self, embed_dim, heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.heads = heads
|
||||
self.head_dim = embed_dim // heads
|
||||
self.attention = BertAttention(embed_dim, heads, dtype, device, operations)
|
||||
self.output = Dino2AttentionOutput(embed_dim, embed_dim, layer_norm_eps, dtype, device, operations)
|
||||
if qk_norm:
|
||||
self.q_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
self.k_norm = operations.LayerNorm(self.head_dim, dtype=dtype, device=device)
|
||||
else:
|
||||
self.q_norm = None
|
||||
self.k_norm = None
|
||||
|
||||
def forward(self, x, mask, optimized_attention):
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
def forward(self, x, mask, optimized_attention, pos=None, rope=None):
|
||||
# Fast path used by the existing CLIP-vision DINOv2 (no DA3 extensions).
|
||||
if self.q_norm is None and rope is None:
|
||||
return self.output(self.attention(x, mask, optimized_attention))
|
||||
|
||||
# DA3 path: do QKV manually so we can apply per-head QK-norm and 2D RoPE.
|
||||
attn = self.attention
|
||||
B, N, C = x.shape
|
||||
h = self.heads
|
||||
d = self.head_dim
|
||||
q = attn.query(x).view(B, N, h, d).transpose(1, 2)
|
||||
k = attn.key(x).view(B, N, h, d).transpose(1, 2)
|
||||
v = attn.value(x).view(B, N, h, d).transpose(1, 2)
|
||||
if self.q_norm is not None:
|
||||
q = self.q_norm(q)
|
||||
k = self.k_norm(k)
|
||||
if rope is not None and pos is not None:
|
||||
q = rope(q, pos)
|
||||
k = rope(k, pos)
|
||||
out = optimized_attention(q, k, v, h, mask=mask, skip_reshape=True)
|
||||
return self.output(out)
|
||||
|
||||
|
||||
class LayerScale(torch.nn.Module):
|
||||
@ -64,9 +98,11 @@ class SwiGLUFFN(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Block(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn,
|
||||
qk_norm=False):
|
||||
super().__init__()
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations)
|
||||
self.attention = Dino2AttentionBlock(dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
qk_norm=qk_norm)
|
||||
self.layer_scale1 = LayerScale(dim, dtype, device, operations)
|
||||
self.layer_scale2 = LayerScale(dim, dtype, device, operations)
|
||||
if use_swiglu_ffn:
|
||||
@ -76,19 +112,90 @@ class Dino2Block(torch.nn.Module):
|
||||
self.norm1 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
self.norm2 = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x, optimized_attention):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), None, optimized_attention))
|
||||
def forward(self, x, optimized_attention, pos=None, rope=None, attn_mask=None):
|
||||
x = x + self.layer_scale1(self.attention(self.norm1(x), attn_mask, optimized_attention,
|
||||
pos=pos, rope=rope))
|
||||
x = x + self.layer_scale2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn):
|
||||
# -----------------------------------------------------------------------------
|
||||
# 2D Rotary position embedding (DA3 extension)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class _PositionGetter:
|
||||
"""Cache (h, w) -> flat (y, x) position grid used to feed ``rope``."""
|
||||
|
||||
def __init__(self):
|
||||
self._cache: dict = {}
|
||||
|
||||
def __call__(self, batch_size: int, height: int, width: int, device) -> torch.Tensor:
|
||||
key = (height, width, device)
|
||||
if key not in self._cache:
|
||||
y = torch.arange(height, device=device)
|
||||
x = torch.arange(width, device=device)
|
||||
self._cache[key] = torch.cartesian_prod(y, x)
|
||||
cached = self._cache[key]
|
||||
return cached.view(1, height * width, 2).expand(batch_size, -1, -1).clone()
|
||||
|
||||
|
||||
class RotaryPositionEmbedding2D(torch.nn.Module):
|
||||
"""2D RoPE used by DA3-Small/Base. No learnable parameters."""
|
||||
|
||||
def __init__(self, frequency: float = 100.0):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([Dino2Block(dim, num_heads, layer_norm_eps, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
for _ in range(num_layers)])
|
||||
self.base_frequency = frequency
|
||||
self._freq_cache: dict = {}
|
||||
|
||||
def _components(self, dim: int, seq_len: int, device, dtype):
|
||||
key = (dim, seq_len, device, dtype)
|
||||
if key not in self._freq_cache:
|
||||
exp = torch.arange(0, dim, 2, device=device).float() / dim
|
||||
inv_freq = 1.0 / (self.base_frequency ** exp)
|
||||
pos = torch.arange(seq_len, device=device, dtype=inv_freq.dtype)
|
||||
ang = torch.einsum("i,j->ij", pos, inv_freq)
|
||||
ang = ang.to(dtype)
|
||||
ang = torch.cat((ang, ang), dim=-1)
|
||||
self._freq_cache[key] = (ang.cos().to(dtype), ang.sin().to(dtype))
|
||||
return self._freq_cache[key]
|
||||
|
||||
@staticmethod
|
||||
def _rotate(x: torch.Tensor) -> torch.Tensor:
|
||||
d = x.shape[-1]
|
||||
x1, x2 = x[..., : d // 2], x[..., d // 2:]
|
||||
return torch.cat((-x2, x1), dim=-1)
|
||||
|
||||
def _apply_1d(self, tokens, positions, cos_c, sin_c):
|
||||
cos = F.embedding(positions, cos_c)[:, None, :, :]
|
||||
sin = F.embedding(positions, sin_c)[:, None, :, :]
|
||||
return (tokens * cos) + (self._rotate(tokens) * sin)
|
||||
|
||||
def forward(self, tokens: torch.Tensor, positions: torch.Tensor) -> torch.Tensor:
|
||||
feature_dim = tokens.size(-1) // 2
|
||||
max_pos = int(positions.max()) + 1
|
||||
cos_c, sin_c = self._components(feature_dim, max_pos, tokens.device, tokens.dtype)
|
||||
v, h = tokens.chunk(2, dim=-1)
|
||||
v = self._apply_1d(v, positions[..., 0], cos_c, sin_c)
|
||||
h = self._apply_1d(h, positions[..., 1], cos_c, sin_c)
|
||||
return torch.cat((v, h), dim=-1)
|
||||
|
||||
|
||||
class Dino2Encoder(torch.nn.Module):
|
||||
def __init__(self, dim, num_heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn,
|
||||
qknorm_start: int = -1):
|
||||
super().__init__()
|
||||
self.layer = torch.nn.ModuleList([
|
||||
Dino2Block(
|
||||
dim, num_heads, layer_norm_eps, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qk_norm=(qknorm_start != -1 and i >= qknorm_start),
|
||||
)
|
||||
for i in range(num_layers)
|
||||
])
|
||||
|
||||
def forward(self, x, intermediate_output=None):
|
||||
# Backward-compat path used by ``ClipVisionModel`` (no DA3 extensions).
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
if intermediate_output is not None:
|
||||
@ -122,16 +229,27 @@ class Dino2PatchEmbeddings(torch.nn.Module):
|
||||
|
||||
|
||||
class Dino2Embeddings(torch.nn.Module):
|
||||
def __init__(self, dim, dtype, device, operations):
|
||||
def __init__(self, dim, dtype, device, operations,
|
||||
patch_size: int = 14, image_size: int = 518,
|
||||
use_mask_token: bool = True,
|
||||
num_camera_tokens: int = 0):
|
||||
super().__init__()
|
||||
patch_size = 14
|
||||
image_size = 518
|
||||
self.patch_size = patch_size
|
||||
self.image_size = image_size
|
||||
|
||||
self.patch_embeddings = Dino2PatchEmbeddings(dim, patch_size=patch_size, image_size=image_size, dtype=dtype, device=device, operations=operations)
|
||||
self.position_embeddings = torch.nn.Parameter(torch.empty(1, (image_size // patch_size) ** 2 + 1, dim, dtype=dtype, device=device))
|
||||
self.cls_token = torch.nn.Parameter(torch.empty(1, 1, dim, dtype=dtype, device=device)) # mask_token is a pre-training param, kept only so strict loading accepts the key.
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
if use_mask_token:
|
||||
self.mask_token = torch.nn.Parameter(torch.empty(1, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.mask_token = None
|
||||
if num_camera_tokens > 0:
|
||||
# DA3 stores (ref_token, src_token) pairs that get injected at the
|
||||
# alt-attn boundary; see ``Dinov2Model._inject_camera_token``.
|
||||
self.camera_token = torch.nn.Parameter(torch.empty(1, num_camera_tokens, dim, dtype=dtype, device=device))
|
||||
else:
|
||||
self.camera_token = None
|
||||
|
||||
def interpolate_pos_encoding(self, x, h_pixels, w_pixels):
|
||||
pos_embed = comfy.model_management.cast_to_device(self.position_embeddings, x.device, torch.float32)
|
||||
@ -140,12 +258,22 @@ class Dino2Embeddings(torch.nn.Module):
|
||||
patch_pos = pos_embed[:, 1:]
|
||||
N = patch_pos.shape[1]
|
||||
M = int(N ** 0.5)
|
||||
assert N == M * M, f"DINOv2 position grid must be square, got N={N} patches (sqrt={M})"
|
||||
h0 = h_pixels // self.patch_size
|
||||
w0 = w_pixels // self.patch_size
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M) # +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# +0.1 matches upstream DINOv2's FP-rounding workaround so the interpolate output size lands on (h0, w0).
|
||||
# scale_factor is (height_scale, width_scale) -- height MUST come first;
|
||||
# swapping these only happens to work for square inputs and breaks
|
||||
# non-square paths like DA3-Small / DA3-Base multi-view.
|
||||
scale_factor = ((h0 + 0.1) / M, (w0 + 0.1) / M)
|
||||
|
||||
patch_pos = patch_pos.reshape(1, M, M, -1).permute(0, 3, 1, 2)
|
||||
patch_pos = torch.nn.functional.interpolate(patch_pos, scale_factor=scale_factor, mode="bicubic", antialias=False)
|
||||
assert (h0, w0) == patch_pos.shape[-2:], (
|
||||
f"Interpolated pos-embed grid {tuple(patch_pos.shape[-2:])} does not match "
|
||||
f"target patch grid ({h0}, {w0}) for input {h_pixels}x{w_pixels} (patch_size={self.patch_size}); "
|
||||
f"check scale_factor axis order and +0.1 rounding workaround"
|
||||
)
|
||||
patch_pos = patch_pos.permute(0, 2, 3, 1).flatten(1, 2)
|
||||
return torch.cat((class_pos, patch_pos), dim=1).to(x.dtype)
|
||||
|
||||
@ -168,12 +296,51 @@ class Dinov2Model(torch.nn.Module):
|
||||
heads = config_dict["num_attention_heads"]
|
||||
layer_norm_eps = config_dict["layer_norm_eps"]
|
||||
use_swiglu_ffn = config_dict["use_swiglu_ffn"]
|
||||
patch_size = config_dict.get("patch_size", 14)
|
||||
image_size = config_dict.get("image_size", 518)
|
||||
use_mask_token = config_dict.get("use_mask_token", True)
|
||||
|
||||
self.embeddings = Dino2Embeddings(dim, dtype, device, operations)
|
||||
self.encoder = Dino2Encoder(dim, heads, layer_norm_eps, num_layers, dtype, device, operations, use_swiglu_ffn = use_swiglu_ffn)
|
||||
# DA3 extensions (all default to disabled).
|
||||
self.alt_start = config_dict.get("alt_start", -1)
|
||||
self.qknorm_start = config_dict.get("qknorm_start", -1)
|
||||
self.rope_start = config_dict.get("rope_start", -1)
|
||||
self.cat_token = config_dict.get("cat_token", False)
|
||||
rope_freq = config_dict.get("rope_freq", 100.0)
|
||||
|
||||
self.embed_dim = dim
|
||||
self.patch_size = patch_size
|
||||
self.num_register_tokens = 0
|
||||
self.patch_start_idx = 1
|
||||
|
||||
if self.rope_start != -1 and rope_freq > 0:
|
||||
self.rope = RotaryPositionEmbedding2D(frequency=rope_freq)
|
||||
self._position_getter = _PositionGetter()
|
||||
else:
|
||||
self.rope = None
|
||||
self._position_getter = None
|
||||
|
||||
# camera_token shape: (1, 2, dim) -> (ref_token, src_token).
|
||||
num_cam_tokens = 2 if self.alt_start != -1 else 0
|
||||
|
||||
self.embeddings = Dino2Embeddings(
|
||||
dim, dtype, device, operations,
|
||||
patch_size=patch_size, image_size=image_size,
|
||||
use_mask_token=use_mask_token, num_camera_tokens=num_cam_tokens,
|
||||
)
|
||||
self.encoder = Dino2Encoder(
|
||||
dim, heads, layer_norm_eps, num_layers, dtype, device, operations,
|
||||
use_swiglu_ffn=use_swiglu_ffn,
|
||||
qknorm_start=self.qknorm_start,
|
||||
)
|
||||
self.layernorm = operations.LayerNorm(dim, eps=layer_norm_eps, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, pixel_values, attention_mask=None, intermediate_output=None):
|
||||
if self.alt_start != -1:
|
||||
raise RuntimeError(
|
||||
"Dinov2Model.forward() is the backward-compatible CLIP-vision path and does not "
|
||||
"apply DA3 extensions (RoPE, alternating attention, camera-token injection). "
|
||||
"Use get_intermediate_layers_da3() for Depth Anything 3 models."
|
||||
)
|
||||
x = self.embeddings(pixel_values)
|
||||
x, i = self.encoder(x, intermediate_output=intermediate_output)
|
||||
x = self.layernorm(x)
|
||||
@ -181,6 +348,7 @@ class Dinov2Model(torch.nn.Module):
|
||||
return x, i, pooled_output, None
|
||||
|
||||
def get_intermediate_layers(self, pixel_values, indices, apply_norm=True):
|
||||
"""Single-view multi-layer feature extraction."""
|
||||
x = self.embeddings(pixel_values)
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
n_layers = len(self.encoder.layer)
|
||||
@ -197,3 +365,132 @@ class Dinov2Model(torch.nn.Module):
|
||||
if i >= max_idx:
|
||||
break
|
||||
return [cache[i] for i in resolved]
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Depth Anything 3 forward
|
||||
# ------------------------------------------------------------------
|
||||
def _prepare_rope_positions(self, B, S, H, W, device):
|
||||
if self.rope is None:
|
||||
return None, None
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
pos = self._position_getter(B * S, ph, pw, device=device)
|
||||
# Shift so the cls/cam token at position 0 is reserved for "no diff".
|
||||
pos = pos + 1
|
||||
cls_pos = torch.zeros(B * S, self.patch_start_idx, 2, device=device, dtype=pos.dtype)
|
||||
# Per-view local: real grid positions for patches, 0 for cls token.
|
||||
pos_local = torch.cat([cls_pos, pos], dim=1)
|
||||
# Global (across views): same grid positions; cls token still at 0,
|
||||
# but patches share the same positions in every view.
|
||||
pos_global = torch.cat([cls_pos, torch.zeros_like(pos) + 1], dim=1)
|
||||
return pos_local, pos_global
|
||||
|
||||
def _inject_camera_token(self, x: torch.Tensor, B: int, S: int, cam_token: "torch.Tensor | None") -> torch.Tensor:
|
||||
# x: (B, S, N, C). Replace token at index 0 with the camera token.
|
||||
if cam_token is not None:
|
||||
inj = cam_token
|
||||
else:
|
||||
ct = comfy.model_management.cast_to_device(self.embeddings.camera_token, x.device, x.dtype)
|
||||
ref_token = ct[:, :1].expand(B, -1, -1)
|
||||
src_token = ct[:, 1:].expand(B, max(S - 1, 0), -1)
|
||||
inj = torch.cat([ref_token, src_token], dim=1)
|
||||
x = x.clone()
|
||||
x[:, :, 0] = inj
|
||||
return x
|
||||
|
||||
def get_intermediate_layers_da3(self, pixel_values, out_layers, cam_token=None, ref_view_strategy="saddle_balanced", export_feat_layers=None):
|
||||
"""Multi-view multi-layer feature extraction used by Depth Anything 3."""
|
||||
if pixel_values.ndim == 4:
|
||||
pixel_values = pixel_values.unsqueeze(1)
|
||||
assert pixel_values.ndim == 5 and pixel_values.shape[2] == 3, \
|
||||
f"expected (B,3,H,W) or (B,S,3,H,W); got {tuple(pixel_values.shape)}"
|
||||
B, S, _, H, W = pixel_values.shape
|
||||
|
||||
# Patch + cls + (interpolated) pos embed for each view.
|
||||
x = pixel_values.reshape(B * S, 3, H, W)
|
||||
x = self.embeddings(x) # (B*S, 1+N, C)
|
||||
x = x.reshape(B, S, x.shape[-2], x.shape[-1]) # (B, S, 1+N, C)
|
||||
|
||||
pos_local, pos_global = self._prepare_rope_positions(B, S, H, W, x.device)
|
||||
# optimized_attention is only used by blocks without QK-norm/RoPE
|
||||
# (vanilla DINOv2 path); enabling-aware blocks fall through to SDPA.
|
||||
optimized_attention = optimized_attention_for_device(x.device, False, small_input=True)
|
||||
|
||||
out_set = set(out_layers)
|
||||
export_set = set(export_feat_layers) if export_feat_layers else set()
|
||||
outputs: list[torch.Tensor] = []
|
||||
aux_outputs: list[torch.Tensor] = []
|
||||
local_x = x
|
||||
b_idx = None
|
||||
|
||||
|
||||
for i, blk in enumerate(self.encoder.layer):
|
||||
apply_rope = self.rope is not None and i >= self.rope_start
|
||||
block_rope = self.rope if apply_rope else None
|
||||
l_pos = pos_local if apply_rope else None
|
||||
g_pos = pos_global if apply_rope else None
|
||||
|
||||
# Reference-view selection threshold: matches the upstream constant
|
||||
# THRESH_FOR_REF_SELECTION = 3. Skipped when a user-supplied
|
||||
# cam_token is provided (camera info already pins the geometry).
|
||||
if (self.alt_start != -1 and i == self.alt_start - 1 and S >= THRESH_FOR_REF_SELECTION and cam_token is None):
|
||||
b_idx = select_reference_view(x, strategy=ref_view_strategy)
|
||||
x = reorder_by_reference(x, b_idx)
|
||||
local_x = reorder_by_reference(local_x, b_idx)
|
||||
|
||||
if self.alt_start != -1 and i == self.alt_start:
|
||||
x = self._inject_camera_token(x, B, S, cam_token)
|
||||
|
||||
if self.alt_start != -1 and i >= self.alt_start and (i % 2 == 1):
|
||||
# Global attention across views: flatten S into the seq dim.
|
||||
t = x.reshape(B, S * x.shape[-2], x.shape[-1])
|
||||
p = g_pos.reshape(B, S * g_pos.shape[-2], g_pos.shape[-1]) if g_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
else:
|
||||
# Per-view local attention.
|
||||
t = x.reshape(B * S, x.shape[-2], x.shape[-1])
|
||||
p = l_pos.reshape(B * S, l_pos.shape[-2], l_pos.shape[-1]) if l_pos is not None else None
|
||||
t = blk(t, optimized_attention=optimized_attention, pos=p, rope=block_rope)
|
||||
x = t.reshape(B, S, x.shape[-2], x.shape[-1])
|
||||
local_x = x
|
||||
|
||||
if i in out_set:
|
||||
if self.cat_token:
|
||||
out_x = torch.cat([local_x, x], dim=-1)
|
||||
else:
|
||||
out_x = x
|
||||
# Restore original view order on the way out so heads see views
|
||||
# in the user's expected order.
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
out_x = restore_original_order(out_x, b_idx)
|
||||
outputs.append(out_x)
|
||||
|
||||
if i in export_set:
|
||||
aux = x
|
||||
if b_idx is not None and self.alt_start != -1:
|
||||
aux = restore_original_order(aux, b_idx)
|
||||
aux_outputs.append(aux)
|
||||
|
||||
# Apply final norm. When cat_token is set, only the right half
|
||||
# ("global" features) is normalised; the left half is left as-is to
|
||||
# match the upstream DA3 head signature.
|
||||
normed: list[torch.Tensor] = []
|
||||
cls_tokens: list[torch.Tensor] = []
|
||||
for out_x in outputs:
|
||||
cls_tokens.append(out_x[:, :, 0])
|
||||
if out_x.shape[-1] == self.embed_dim:
|
||||
normed.append(self.layernorm(out_x))
|
||||
elif out_x.shape[-1] == self.embed_dim * 2:
|
||||
left = out_x[..., :self.embed_dim]
|
||||
right = self.layernorm(out_x[..., self.embed_dim:])
|
||||
normed.append(torch.cat([left, right], dim=-1))
|
||||
else:
|
||||
raise ValueError(f"Unexpected token width: {out_x.shape[-1]}")
|
||||
|
||||
# Drop cls/cam token from the patch sequence.
|
||||
normed = [o[..., 1 + self.num_register_tokens:, :] for o in normed]
|
||||
|
||||
# Final layernorm + drop cls token from auxiliary features too.
|
||||
aux_normed = [self.layernorm(o)[..., 1 + self.num_register_tokens:, :]
|
||||
for o in aux_outputs]
|
||||
return list(zip(normed, cls_tokens)), aux_normed
|
||||
|
||||
25
comfy/ldm/colormap.py
Normal file
25
comfy/ldm/colormap.py
Normal file
@ -0,0 +1,25 @@
|
||||
"""Colormap utilities for depth and geometry visualisation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the Turbo colormap.
|
||||
|
||||
Args:
|
||||
x: Float tensor with values in [0, 1].
|
||||
|
||||
Returns:
|
||||
RGB tensor of the same shape as ``x`` with a trailing size-3 dimension.
|
||||
"""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
177
comfy/ldm/depth_anything_3/camera.py
Normal file
177
comfy/ldm/depth_anything_3/camera.py
Normal file
@ -0,0 +1,177 @@
|
||||
"""Camera-token encoder and decoder for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention_for_device
|
||||
from .transform import affine_inverse, extri_intri_to_pose_encoding
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------
|
||||
# Building blocks (mirror depth_anything_3.model.utils.{attention,block})
|
||||
# -----------------------------------------------------------------------
|
||||
|
||||
|
||||
class _Mlp(nn.Module):
|
||||
"""Standard 2-layer MLP with GELU. Matches upstream ``utils.attention.Mlp``."""
|
||||
|
||||
def __init__(self, in_features, hidden_features=None, out_features=None, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
self.fc1 = operations.Linear(in_features, hidden_features, bias=True, device=device, dtype=dtype)
|
||||
self.fc2 = operations.Linear(hidden_features, out_features, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
return self.fc2(F.gelu(self.fc1(x)))
|
||||
|
||||
|
||||
class _LayerScale(nn.Module):
|
||||
"""Per-channel learnable scaling. Matches upstream LayerScale."""
|
||||
|
||||
def __init__(self, dim, *, device=None, dtype=None):
|
||||
super().__init__()
|
||||
self.gamma = nn.Parameter(torch.empty(dim, device=device, dtype=dtype))
|
||||
|
||||
def forward(self, x):
|
||||
return x * self.gamma.to(dtype=x.dtype, device=x.device)
|
||||
|
||||
|
||||
class _Attention(nn.Module):
|
||||
""" Self-attention with fused QKV projection. Mirrors upstream utils.attention.Attention;
|
||||
Layout matches the HF safetensors (attn.qkv.{weight,bias} and attn.proj.{weight,bias})."""
|
||||
|
||||
def __init__(self, dim, num_heads, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.qkv = operations.Linear(dim, dim * 3, bias=True, device=device, dtype=dtype)
|
||||
self.proj = operations.Linear(dim, dim, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, x):
|
||||
B, N, C = x.shape
|
||||
qkv = self.qkv(x).reshape(B, N, 3, C)
|
||||
q, k, v = qkv.unbind(2) # each (B, N, C)
|
||||
attn_fn = optimized_attention_for_device(x.device, small_input=True)
|
||||
out = attn_fn(q, k, v, heads=self.num_heads)
|
||||
return self.proj(out)
|
||||
|
||||
|
||||
class _Block(nn.Module):
|
||||
"""Pre-norm transformer block with LayerScale. Used by :class:CameraEnc. Layout follows upstream utils.block.Block."""
|
||||
|
||||
def __init__(self, dim, num_heads, mlp_ratio=4, init_values=0.01, *, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.norm1 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.attn = _Attention(dim, num_heads, device=device, dtype=dtype, operations=operations)
|
||||
self.ls1 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
self.norm2 = operations.LayerNorm(dim, device=device, dtype=dtype)
|
||||
self.mlp = _Mlp(in_features=dim, hidden_features=int(dim * mlp_ratio), device=device, dtype=dtype, operations=operations)
|
||||
self.ls2 = _LayerScale(dim, device=device, dtype=dtype) if init_values else nn.Identity()
|
||||
|
||||
def forward(self, x):
|
||||
x = x + self.ls1(self.attn(self.norm1(x)))
|
||||
x = x + self.ls2(self.mlp(self.norm2(x)))
|
||||
return x
|
||||
|
||||
|
||||
class CameraEnc(nn.Module):
|
||||
"""Encode per-view (extrinsics, intrinsics) into a camera token.
|
||||
|
||||
Maps a 9-D pose-encoding vector through a small MLP up to the backbone's
|
||||
``embed_dim``, then runs ``trunk_depth`` transformer blocks. The output
|
||||
has shape ``(B, S, embed_dim)`` and is injected at block ``alt_start``
|
||||
of the DINOv2 backbone in place of the cls token.
|
||||
|
||||
Parameters mirror the upstream ``cam_enc.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_out: int = 1024,
|
||||
dim_in: int = 9,
|
||||
trunk_depth: int = 4,
|
||||
target_dim: int = 9,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: int = 4,
|
||||
init_values: float = 0.01,
|
||||
*,
|
||||
device=None, dtype=None, operations=None,
|
||||
**_kwargs,
|
||||
):
|
||||
super().__init__()
|
||||
self.target_dim = target_dim
|
||||
self.trunk_depth = trunk_depth
|
||||
self.trunk = nn.Sequential(*[
|
||||
_Block(dim_out, num_heads=num_heads, mlp_ratio=mlp_ratio,
|
||||
init_values=init_values,
|
||||
device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(trunk_depth)
|
||||
])
|
||||
self.token_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.trunk_norm = operations.LayerNorm(dim_out, device=device, dtype=dtype)
|
||||
self.pose_branch = _Mlp(
|
||||
in_features=dim_in,
|
||||
hidden_features=dim_out // 2,
|
||||
out_features=dim_out,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
|
||||
def forward(self, extrinsics: torch.Tensor, intrinsics: torch.Tensor,
|
||||
image_size_hw) -> torch.Tensor:
|
||||
"""Encode camera parameters into ``(B, S, dim_out)`` tokens."""
|
||||
c2ws = affine_inverse(extrinsics)
|
||||
pose_encoding = extri_intri_to_pose_encoding(c2ws, intrinsics, image_size_hw)
|
||||
tokens = self.pose_branch(pose_encoding.to(self.pose_branch.fc1.weight.dtype))
|
||||
tokens = self.token_norm(tokens)
|
||||
tokens = self.trunk(tokens)
|
||||
tokens = self.trunk_norm(tokens)
|
||||
return tokens
|
||||
|
||||
|
||||
class CameraDec(nn.Module):
|
||||
"""Decode the final cam token into a 9-D pose encoding.
|
||||
|
||||
Output layout: ``[T(3), quat_xyzw(4), fov_h, fov_w]``. The translation is
|
||||
always predicted by the network; the quaternion and FoV can either be
|
||||
predicted or supplied via ``camera_encoding`` (used at training time
|
||||
when GT cameras are available -- not exercised at inference here).
|
||||
|
||||
Parameters mirror the upstream ``cam_dec.py`` so HF weights load directly.
|
||||
"""
|
||||
|
||||
def __init__(self, dim_in: int = 1536,
|
||||
*, device=None, dtype=None, operations=None, **_kwargs):
|
||||
super().__init__()
|
||||
d = dim_in
|
||||
self.backbone = nn.Sequential(
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
operations.Linear(d, d, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
self.fc_t = operations.Linear(d, 3, device=device, dtype=dtype)
|
||||
self.fc_qvec = operations.Linear(d, 4, device=device, dtype=dtype)
|
||||
self.fc_fov = nn.Sequential(
|
||||
operations.Linear(d, 2, device=device, dtype=dtype),
|
||||
nn.ReLU(),
|
||||
)
|
||||
|
||||
def forward(self, feat: torch.Tensor,
|
||||
camera_encoding: "torch.Tensor | None" = None) -> torch.Tensor:
|
||||
"""Decode ``(B, N, dim_in)`` cam tokens into ``(B, N, 9)`` pose enc."""
|
||||
B, N = feat.shape[:2]
|
||||
feat = feat.reshape(B * N, -1)
|
||||
feat = self.backbone(feat)
|
||||
out_t = self.fc_t(feat.float()).reshape(B, N, 3)
|
||||
if camera_encoding is None:
|
||||
out_qvec = self.fc_qvec(feat.float()).reshape(B, N, 4)
|
||||
out_fov = self.fc_fov(feat.float()).reshape(B, N, 2)
|
||||
else:
|
||||
out_qvec = camera_encoding[..., 3:7]
|
||||
out_fov = camera_encoding[..., -2:]
|
||||
return torch.cat([out_t, out_qvec, out_fov], dim=-1)
|
||||
489
comfy/ldm/depth_anything_3/dpt.py
Normal file
489
comfy/ldm/depth_anything_3/dpt.py
Normal file
@ -0,0 +1,489 @@
|
||||
"""DPT / DualDPT heads for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import List, Optional, Sequence, Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class Permute(nn.Module):
|
||||
def __init__(self, dims: Tuple[int, ...]):
|
||||
super().__init__()
|
||||
self.dims = dims
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
return x.permute(*self.dims)
|
||||
|
||||
|
||||
def _custom_interpolate(
|
||||
x: torch.Tensor,
|
||||
size: Optional[Tuple[int, int]] = None,
|
||||
scale_factor: Optional[float] = None,
|
||||
mode: str = "bilinear",
|
||||
align_corners: bool = True,
|
||||
) -> torch.Tensor:
|
||||
if size is None:
|
||||
assert scale_factor is not None
|
||||
size = (int(x.shape[-2] * scale_factor), int(x.shape[-1] * scale_factor))
|
||||
INT_MAX = 1610612736
|
||||
total = size[0] * size[1] * x.shape[0] * x.shape[1]
|
||||
if total > INT_MAX:
|
||||
chunks = torch.chunk(x, chunks=(total // INT_MAX) + 1, dim=0)
|
||||
outs = [F.interpolate(c, size=size, mode=mode, align_corners=align_corners) for c in chunks]
|
||||
return torch.cat(outs, dim=0).contiguous()
|
||||
return F.interpolate(x, size=size, mode=mode, align_corners=align_corners)
|
||||
|
||||
|
||||
def _create_uv_grid(width: int, height: int, aspect_ratio: float, dtype, device) -> torch.Tensor:
|
||||
"""Normalised UV grid spanning (-x_span, -y_span)..(x_span, y_span)."""
|
||||
diag_factor = (aspect_ratio ** 2 + 1.0) ** 0.5
|
||||
span_x = aspect_ratio / diag_factor
|
||||
span_y = 1.0 / diag_factor
|
||||
left_x = -span_x * (width - 1) / width
|
||||
right_x = span_x * (width - 1) / width
|
||||
top_y = -span_y * (height - 1) / height
|
||||
bottom_y = span_y * (height - 1) / height
|
||||
x_coords = torch.linspace(left_x, right_x, steps=width, dtype=dtype, device=device)
|
||||
y_coords = torch.linspace(top_y, bottom_y, steps=height, dtype=dtype, device=device)
|
||||
uu, vv = torch.meshgrid(x_coords, y_coords, indexing="xy")
|
||||
return torch.stack((uu, vv), dim=-1) # (H, W, 2)
|
||||
|
||||
|
||||
def _make_sincos_pos_embed(embed_dim: int, pos: torch.Tensor, omega_0: float = 100.0) -> torch.Tensor:
|
||||
omega = torch.arange(embed_dim // 2, dtype=torch.float32, device=pos.device)
|
||||
omega = 1.0 / omega_0 ** (omega / (embed_dim / 2.0))
|
||||
pos = pos.reshape(-1)
|
||||
out = torch.einsum("m,d->md", pos, omega)
|
||||
return torch.cat([out.sin(), out.cos()], dim=1).float()
|
||||
|
||||
|
||||
def _position_grid_to_embed(pos_grid: torch.Tensor, embed_dim: int, omega_0: float = 100.0) -> torch.Tensor:
|
||||
H, W, _ = pos_grid.shape
|
||||
pos_flat = pos_grid.reshape(-1, 2)
|
||||
emb_x = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 0], omega_0=omega_0)
|
||||
emb_y = _make_sincos_pos_embed(embed_dim // 2, pos_flat[:, 1], omega_0=omega_0)
|
||||
emb = torch.cat([emb_x, emb_y], dim=-1)
|
||||
return emb.view(H, W, embed_dim)
|
||||
|
||||
|
||||
def _add_pos_embed(x: torch.Tensor, W: int, H: int, ratio: float = 0.1) -> torch.Tensor:
|
||||
"""Stateless UV positional embedding added to a feature map (B, C, h, w)."""
|
||||
pw, ph = x.shape[-1], x.shape[-2]
|
||||
pe = _create_uv_grid(pw, ph, aspect_ratio=W / H, dtype=x.dtype, device=x.device)
|
||||
pe = _position_grid_to_embed(pe, x.shape[1]) * ratio
|
||||
pe = pe.permute(2, 0, 1)[None].expand(x.shape[0], -1, -1, -1).to(dtype=x.dtype)
|
||||
return x + pe
|
||||
|
||||
|
||||
def _apply_activation(x: torch.Tensor, activation: str) -> torch.Tensor:
|
||||
act = (activation or "linear").lower()
|
||||
if act == "exp":
|
||||
return torch.exp(x)
|
||||
if act == "expp1":
|
||||
return torch.exp(x) + 1
|
||||
if act == "expm1":
|
||||
return torch.expm1(x)
|
||||
if act == "relu":
|
||||
return torch.relu(x)
|
||||
if act == "sigmoid":
|
||||
return torch.sigmoid(x)
|
||||
if act == "softplus":
|
||||
return F.softplus(x)
|
||||
if act == "tanh":
|
||||
return torch.tanh(x)
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Fusion building blocks
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ResidualConvUnit(nn.Module):
|
||||
def __init__(self, features: int, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.conv1 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.conv2 = operations.Conv2d(features, features, 3, 1, 1, bias=True, device=device, dtype=dtype)
|
||||
self.activation = nn.ReLU(inplace=False)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
out = self.activation(x)
|
||||
out = self.conv1(out)
|
||||
out = self.activation(out)
|
||||
out = self.conv2(out)
|
||||
return out + x
|
||||
|
||||
|
||||
class FeatureFusionBlock(nn.Module):
|
||||
def __init__(self, features: int, has_residual: bool = True, align_corners: bool = True, device=None, dtype=None, operations=None):
|
||||
super().__init__()
|
||||
self.align_corners = align_corners
|
||||
self.has_residual = has_residual
|
||||
if has_residual:
|
||||
self.resConfUnit1 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
else:
|
||||
self.resConfUnit1 = None
|
||||
self.resConfUnit2 = ResidualConvUnit(features, device=device, dtype=dtype, operations=operations)
|
||||
self.out_conv = operations.Conv2d(features, features, 1, 1, 0, bias=True, device=device, dtype=dtype)
|
||||
|
||||
def forward(self, *xs: torch.Tensor, size: Optional[Tuple[int, int]] = None) -> torch.Tensor:
|
||||
y = xs[0]
|
||||
if self.has_residual and len(xs) > 1 and self.resConfUnit1 is not None:
|
||||
y = y + self.resConfUnit1(xs[1])
|
||||
y = self.resConfUnit2(y)
|
||||
if size is None:
|
||||
up_kwargs = {"scale_factor": 2.0}
|
||||
else:
|
||||
up_kwargs = {"size": size}
|
||||
y = _custom_interpolate(y, **up_kwargs, mode="bilinear", align_corners=self.align_corners)
|
||||
y = self.out_conv(y)
|
||||
return y
|
||||
|
||||
|
||||
class _Scratch(nn.Module):
|
||||
"""Container that mirrors upstream ``scratch`` attribute layout."""
|
||||
|
||||
|
||||
def _make_scratch(in_shape: List[int], out_shape: int, device=None, dtype=None, operations=None) -> _Scratch:
|
||||
scratch = _Scratch()
|
||||
scratch.layer1_rn = operations.Conv2d(in_shape[0], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer2_rn = operations.Conv2d(in_shape[1], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer3_rn = operations.Conv2d(in_shape[2], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
scratch.layer4_rn = operations.Conv2d(in_shape[3], out_shape, 3, 1, 1, bias=False, device=device, dtype=dtype)
|
||||
return scratch
|
||||
|
||||
|
||||
def _make_fusion_block(features: int, has_residual: bool = True, device=None, dtype=None, operations=None) -> FeatureFusionBlock:
|
||||
return FeatureFusionBlock(features, has_residual=has_residual, align_corners=True, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DPT (single head + optional sky head) -- used by DA3Mono/Metric
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DPT(nn.Module):
|
||||
"""Single-head DPT used by DA3Mono-Large and DA3Metric-Large."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 1,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = False,
|
||||
down_ratio: int = 1,
|
||||
head_name: str = "depth",
|
||||
use_sky_head: bool = True,
|
||||
sky_name: str = "sky",
|
||||
sky_activation: str = "relu",
|
||||
norm_type: str = "idt",
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.head_main = head_name
|
||||
self.sky_name = sky_name
|
||||
self.out_dim = output_dim
|
||||
self.has_conf = output_dim > 1
|
||||
self.use_sky_head = use_sky_head
|
||||
self.sky_activation = sky_activation
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
|
||||
if norm_type == "layer":
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
else:
|
||||
self.norm = nn.Identity()
|
||||
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
if self.use_sky_head:
|
||||
self.scratch.sky_output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
# feats[i][0] is the patch-token tensor with shape (B, S, N_patch, C)
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
out = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
out = self.scratch.refinenet3(out, l3_rn, size=l2_rn.shape[2:])
|
||||
out = self.scratch.refinenet2(out, l2_rn, size=l1_rn.shape[2:])
|
||||
out = self.scratch.refinenet1(out, l1_rn)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
fused = self.scratch.output_conv1(out)
|
||||
fused = _custom_interpolate(fused, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
fused = _add_pos_embed(fused, W, H)
|
||||
feat = fused
|
||||
|
||||
main_logits = self.scratch.output_conv2(feat)
|
||||
outs = {}
|
||||
if self.has_conf:
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
outs[self.head_main] = pred.squeeze(-1).view(B, S, *pred.shape[1:-1])
|
||||
outs[f"{self.head_main}_conf"] = conf.view(B, S, *conf.shape[1:])
|
||||
else:
|
||||
pred = _apply_activation(main_logits, self.activation)
|
||||
outs[self.head_main] = pred.squeeze(1).view(B, S, *pred.shape[2:])
|
||||
|
||||
if self.use_sky_head:
|
||||
sky_logits = self.scratch.sky_output_conv2(feat)
|
||||
if self.sky_activation.lower() == "sigmoid":
|
||||
sky = torch.sigmoid(sky_logits)
|
||||
elif self.sky_activation.lower() == "relu":
|
||||
sky = F.relu(sky_logits)
|
||||
else:
|
||||
sky = sky_logits
|
||||
outs[self.sky_name] = sky.squeeze(1).view(B, S, *sky.shape[2:])
|
||||
|
||||
return outs
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# DualDPT (depth + auxiliary "ray" head) -- used by DA3-Small / DA3-Base
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class DualDPT(nn.Module):
|
||||
"""Two-head DPT used by DA3-Small / DA3-Base."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim_in: int,
|
||||
patch_size: int = 14,
|
||||
output_dim: int = 2,
|
||||
activation: str = "exp",
|
||||
conf_activation: str = "expp1",
|
||||
features: int = 256,
|
||||
out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
pos_embed: bool = True,
|
||||
down_ratio: int = 1,
|
||||
aux_pyramid_levels: int = 4,
|
||||
aux_out1_conv_num: int = 5,
|
||||
head_names: Tuple[str, str] = ("depth", "ray"),
|
||||
device=None, dtype=None, operations=None,
|
||||
):
|
||||
super().__init__()
|
||||
self.patch_size = patch_size
|
||||
self.activation = activation
|
||||
self.conf_activation = conf_activation
|
||||
self.pos_embed = pos_embed
|
||||
self.down_ratio = down_ratio
|
||||
self.aux_levels = aux_pyramid_levels
|
||||
self.aux_out1_conv_num = aux_out1_conv_num
|
||||
self.head_main, self.head_aux = head_names
|
||||
self.intermediate_layer_idx: Tuple[int, int, int, int] = (0, 1, 2, 3)
|
||||
# Toggle the auxiliary ray branch at runtime. Default off (mono path).
|
||||
# DepthAnything3Net flips this on when running multi-view + ray-pose.
|
||||
self.enable_aux: bool = False
|
||||
|
||||
self.norm = operations.LayerNorm(dim_in, device=device, dtype=dtype)
|
||||
out_channels = list(out_channels)
|
||||
self.projects = nn.ModuleList([
|
||||
operations.Conv2d(dim_in, oc, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype)
|
||||
for oc in out_channels
|
||||
])
|
||||
self.resize_layers = nn.ModuleList([
|
||||
operations.ConvTranspose2d(out_channels[0], out_channels[0], kernel_size=4, stride=4, padding=0, device=device, dtype=dtype),
|
||||
operations.ConvTranspose2d(out_channels[1], out_channels[1], kernel_size=2, stride=2, padding=0, device=device, dtype=dtype),
|
||||
nn.Identity(),
|
||||
operations.Conv2d(out_channels[3], out_channels[3], kernel_size=3, stride=2, padding=1, device=device, dtype=dtype),
|
||||
])
|
||||
|
||||
self.scratch = _make_scratch(out_channels, features, device=device, dtype=dtype, operations=operations)
|
||||
# Main fusion chain
|
||||
self.scratch.refinenet1 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3 = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4 = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
# Auxiliary fusion chain (separate copies)
|
||||
self.scratch.refinenet1_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet2_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet3_aux = _make_fusion_block(features, device=device, dtype=dtype, operations=operations)
|
||||
self.scratch.refinenet4_aux = _make_fusion_block(features, has_residual=False, device=device, dtype=dtype, operations=operations)
|
||||
|
||||
head_features_1 = features
|
||||
head_features_2 = 32
|
||||
|
||||
# Main head neck + final projection
|
||||
self.scratch.output_conv1 = operations.Conv2d(
|
||||
head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1,
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
self.scratch.output_conv2 = nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, output_dim, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
# Aux pre-head per level (multi-level pyramid)
|
||||
self.scratch.output_conv1_aux = nn.ModuleList([
|
||||
self._make_aux_out1_block(head_features_1, device=device, dtype=dtype, operations=operations)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
# Aux final projection per level (includes LayerNorm permute path).
|
||||
ln_seq = [Permute((0, 2, 3, 1)),
|
||||
operations.LayerNorm(head_features_2, device=device, dtype=dtype),
|
||||
Permute((0, 3, 1, 2))]
|
||||
self.scratch.output_conv2_aux = nn.ModuleList([
|
||||
nn.Sequential(
|
||||
operations.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1, device=device, dtype=dtype),
|
||||
*ln_seq,
|
||||
nn.ReLU(inplace=False),
|
||||
operations.Conv2d(head_features_2, 7, kernel_size=1, stride=1, padding=0, device=device, dtype=dtype),
|
||||
)
|
||||
for _ in range(self.aux_levels)
|
||||
])
|
||||
|
||||
@staticmethod
|
||||
def _make_aux_out1_block(in_ch: int, *, device=None, dtype=None, operations=None) -> nn.Sequential:
|
||||
# aux_out1_conv_num=5 in all Apache-2.0 variants.
|
||||
return nn.Sequential(
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch // 2, in_ch, 3, 1, 1, device=device, dtype=dtype),
|
||||
operations.Conv2d(in_ch, in_ch // 2, 3, 1, 1, device=device, dtype=dtype),
|
||||
)
|
||||
|
||||
def forward(self, feats: List[torch.Tensor], H: int, W: int, patch_start_idx: int = 0, **_kwargs) -> dict:
|
||||
B, S, N, C = feats[0][0].shape
|
||||
feats_flat = [feat[0].reshape(B * S, N, C) for feat in feats]
|
||||
|
||||
ph, pw = H // self.patch_size, W // self.patch_size
|
||||
resized = []
|
||||
for stage_idx, take_idx in enumerate(self.intermediate_layer_idx):
|
||||
x = feats_flat[take_idx][:, patch_start_idx:]
|
||||
x = self.norm(x)
|
||||
x = x.permute(0, 2, 1).contiguous().reshape(B * S, C, ph, pw)
|
||||
x = self.projects[stage_idx](x)
|
||||
if self.pos_embed:
|
||||
x = _add_pos_embed(x, W, H)
|
||||
x = self.resize_layers[stage_idx](x)
|
||||
resized.append(x)
|
||||
|
||||
l1_rn = self.scratch.layer1_rn(resized[0])
|
||||
l2_rn = self.scratch.layer2_rn(resized[1])
|
||||
l3_rn = self.scratch.layer3_rn(resized[2])
|
||||
l4_rn = self.scratch.layer4_rn(resized[3])
|
||||
|
||||
# Main pyramid (output_conv1 is applied inside the upstream `_fuse`,
|
||||
# before interpolation -- replicate that order here).
|
||||
m = self.scratch.refinenet4(l4_rn, size=l3_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
a4 = self.scratch.refinenet4_aux(l4_rn, size=l3_rn.shape[2:])
|
||||
aux_pyr = [a4]
|
||||
m = self.scratch.refinenet3(m, l3_rn, size=l2_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet3_aux(aux_pyr[-1], l3_rn, size=l2_rn.shape[2:]))
|
||||
m = self.scratch.refinenet2(m, l2_rn, size=l1_rn.shape[2:])
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet2_aux(aux_pyr[-1], l2_rn, size=l1_rn.shape[2:]))
|
||||
m = self.scratch.refinenet1(m, l1_rn)
|
||||
if self.enable_aux:
|
||||
aux_pyr.append(self.scratch.refinenet1_aux(aux_pyr[-1], l1_rn))
|
||||
m = self.scratch.output_conv1(m)
|
||||
|
||||
h_out = int(ph * self.patch_size / self.down_ratio)
|
||||
w_out = int(pw * self.patch_size / self.down_ratio)
|
||||
|
||||
m = _custom_interpolate(m, (h_out, w_out), mode="bilinear", align_corners=True)
|
||||
if self.pos_embed:
|
||||
m = _add_pos_embed(m, W, H)
|
||||
main_logits = self.scratch.output_conv2(m)
|
||||
fmap = main_logits.permute(0, 2, 3, 1)
|
||||
depth_pred = _apply_activation(fmap[..., :-1], self.activation)
|
||||
depth_conf = _apply_activation(fmap[..., -1], self.conf_activation)
|
||||
|
||||
outs = {
|
||||
self.head_main: depth_pred.squeeze(-1).view(B, S, *depth_pred.shape[1:-1]),
|
||||
f"{self.head_main}_conf": depth_conf.view(B, S, *depth_conf.shape[1:]),
|
||||
}
|
||||
|
||||
if self.enable_aux:
|
||||
# Auxiliary "ray" head (multi-level inside) -- only the last level
|
||||
# is returned. Mirrors upstream ``DualDPT._fuse`` + ``_forward_impl``:
|
||||
# each aux pyramid level goes through ``output_conv1_aux[i]``
|
||||
# (5-layer conv stack that ends at ``features // 2`` channels),
|
||||
# then the last level optionally gets a pos-embed and finally
|
||||
# ``output_conv2_aux[-1]``.
|
||||
aux_processed = [
|
||||
self.scratch.output_conv1_aux[i](a) for i, a in enumerate(aux_pyr)
|
||||
]
|
||||
last_aux = aux_processed[-1]
|
||||
if self.pos_embed:
|
||||
last_aux = _add_pos_embed(last_aux, W, H)
|
||||
last_aux_logits = self.scratch.output_conv2_aux[-1](last_aux)
|
||||
fmap_last = last_aux_logits.permute(0, 2, 3, 1)
|
||||
# Channels: [ray(6), ray_conf(1)]; ray uses 'linear' activation.
|
||||
aux_pred = fmap_last[..., :-1]
|
||||
aux_conf = _apply_activation(fmap_last[..., -1], self.conf_activation)
|
||||
outs[self.head_aux] = aux_pred.view(B, S, *aux_pred.shape[1:])
|
||||
outs[f"{self.head_aux}_conf"] = aux_conf.view(B, S, *aux_conf.shape[1:])
|
||||
|
||||
return outs
|
||||
236
comfy/ldm/depth_anything_3/model.py
Normal file
236
comfy/ldm/depth_anything_3/model.py
Normal file
@ -0,0 +1,236 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Dict, Optional, Sequence
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from comfy.image_encoders.dino2 import Dinov2Model
|
||||
|
||||
from .camera import CameraDec, CameraEnc
|
||||
from .dpt import DPT, DualDPT
|
||||
from .ray_pose import get_extrinsic_from_camray
|
||||
from .transform import affine_inverse, pose_encoding_to_extri_intri
|
||||
|
||||
|
||||
_HEAD_REGISTRY = {
|
||||
"dpt": DPT,
|
||||
"dualdpt": DualDPT,
|
||||
}
|
||||
|
||||
|
||||
# Backbone presets (mirror the upstream DINOv2 ViT variants).
|
||||
_BACKBONE_PRESETS = {
|
||||
"vits": dict(hidden_size=384, num_hidden_layers=12, num_attention_heads=6, use_swiglu_ffn=False),
|
||||
"vitb": dict(hidden_size=768, num_hidden_layers=12, num_attention_heads=12, use_swiglu_ffn=False),
|
||||
"vitl": dict(hidden_size=1024, num_hidden_layers=24, num_attention_heads=16, use_swiglu_ffn=False),
|
||||
"vitg": dict(hidden_size=1536, num_hidden_layers=40, num_attention_heads=24, use_swiglu_ffn=True),
|
||||
}
|
||||
|
||||
|
||||
def _build_backbone_config(
|
||||
backbone_name: str,
|
||||
*,
|
||||
alt_start: int,
|
||||
qknorm_start: int,
|
||||
rope_start: int,
|
||||
cat_token: bool,
|
||||
) -> dict:
|
||||
if backbone_name not in _BACKBONE_PRESETS:
|
||||
raise ValueError(f"Unknown DINOv2 backbone variant: {backbone_name!r}")
|
||||
cfg = dict(_BACKBONE_PRESETS[backbone_name])
|
||||
cfg.update(dict(
|
||||
layer_norm_eps=1e-6,
|
||||
patch_size=14,
|
||||
image_size=518,
|
||||
# No mask_token in DA3 weights; omit param to avoid load warnings.
|
||||
use_mask_token=False,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
rope_freq=100.0,
|
||||
))
|
||||
return cfg
|
||||
|
||||
|
||||
class DepthAnything3Net(nn.Module):
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
# --- Backbone ---
|
||||
backbone_name: str = "vitl",
|
||||
out_layers: Sequence[int] = (4, 11, 17, 23),
|
||||
alt_start: int = -1,
|
||||
qknorm_start: int = -1,
|
||||
rope_start: int = -1,
|
||||
cat_token: bool = False,
|
||||
# --- Head ---
|
||||
head_type: str = "dpt", # dpt or dualdpt
|
||||
head_dim_in: int = 1024,
|
||||
head_output_dim: int = 1, # 1 = depth only, 2 = depth+conf
|
||||
head_features: int = 256,
|
||||
head_out_channels: Sequence[int] = (256, 512, 1024, 1024),
|
||||
head_use_sky_head: bool = True, # ignored by DualDPT
|
||||
head_pos_embed: Optional[bool] = None, # default: True for DualDPT, False for DPT
|
||||
# --- Camera (multi-view) ---
|
||||
has_cam_enc: bool = False,
|
||||
has_cam_dec: bool = False,
|
||||
cam_dim_out: Optional[int] = None, # CameraEnc dim_out (defaults to embed_dim)
|
||||
cam_dec_dim_in: Optional[int] = None, # CameraDec dim_in (defaults to 2*embed_dim with cat_token)
|
||||
# ComfyUI plumbing
|
||||
device=None, dtype=None, operations=None,
|
||||
**_ignored,
|
||||
):
|
||||
super().__init__()
|
||||
head_cls = _HEAD_REGISTRY[head_type.lower()]
|
||||
self.head_type = head_type.lower()
|
||||
self.has_sky = (self.head_type == "dpt") and head_use_sky_head
|
||||
self.has_conf = head_output_dim > 1
|
||||
self.out_layers = list(out_layers)
|
||||
|
||||
backbone_cfg = _build_backbone_config(
|
||||
backbone_name,
|
||||
alt_start=alt_start,
|
||||
qknorm_start=qknorm_start,
|
||||
rope_start=rope_start,
|
||||
cat_token=cat_token,
|
||||
)
|
||||
self.backbone = Dinov2Model(backbone_cfg, dtype, device, operations)
|
||||
|
||||
head_kwargs = dict(
|
||||
dim_in=head_dim_in,
|
||||
patch_size=self.PATCH_SIZE,
|
||||
output_dim=head_output_dim,
|
||||
features=head_features,
|
||||
out_channels=tuple(head_out_channels),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
if self.head_type == "dpt":
|
||||
head_kwargs.update(
|
||||
use_sky_head=head_use_sky_head,
|
||||
pos_embed=(False if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
else: # dualdpt
|
||||
head_kwargs.update(
|
||||
pos_embed=(True if head_pos_embed is None else head_pos_embed),
|
||||
)
|
||||
self.head = head_cls(**head_kwargs)
|
||||
|
||||
# Built only if checkpoint has weights; cam_enc output dim == embed_dim.
|
||||
embed_dim = backbone_cfg["hidden_size"]
|
||||
if has_cam_enc:
|
||||
self.cam_enc = CameraEnc(
|
||||
dim_out=cam_dim_out if cam_dim_out is not None else embed_dim,
|
||||
num_heads=max(1, embed_dim // 64),
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_enc = None
|
||||
if has_cam_dec:
|
||||
default_dim = embed_dim * (2 if cat_token else 1)
|
||||
self.cam_dec = CameraDec(
|
||||
dim_in=cam_dec_dim_in if cam_dec_dim_in is not None else default_dim,
|
||||
device=device, dtype=dtype, operations=operations,
|
||||
)
|
||||
else:
|
||||
self.cam_dec = None
|
||||
|
||||
self.dtype = dtype
|
||||
|
||||
def forward(
|
||||
self,
|
||||
image: torch.Tensor,
|
||||
extrinsics: Optional[torch.Tensor] = None,
|
||||
intrinsics: Optional[torch.Tensor] = None,
|
||||
*,
|
||||
use_ray_pose: bool = False,
|
||||
ref_view_strategy: str = "saddle_balanced",
|
||||
export_feat_layers: Optional[Sequence[int]] = None,
|
||||
**_unused,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""Run depth and optionally pose prediction."""
|
||||
if image.ndim == 4:
|
||||
image = image.unsqueeze(1) # (B, 1, 3, H, W)
|
||||
assert image.ndim == 5 and image.shape[2] == 3, \
|
||||
f"image must be (B,3,H,W) or (B,S,3,H,W); got {tuple(image.shape)}"
|
||||
|
||||
B, S, _, H, W = image.shape
|
||||
assert H % self.PATCH_SIZE == 0 and W % self.PATCH_SIZE == 0, \
|
||||
f"image H,W must be multiples of {self.PATCH_SIZE}; got {(H, W)}"
|
||||
|
||||
# Camera-token preparation (multi-view path).
|
||||
cam_token = None
|
||||
if extrinsics is not None and intrinsics is not None and self.cam_enc is not None:
|
||||
cam_token = self.cam_enc(extrinsics, intrinsics, (H, W))
|
||||
|
||||
# Toggle aux ray output on/off depending on what the caller asked for.
|
||||
if isinstance(self.head, DualDPT):
|
||||
self.head.enable_aux = bool(use_ray_pose)
|
||||
|
||||
feats, aux_feats = self.backbone.get_intermediate_layers_da3(
|
||||
image, self.out_layers, cam_token=cam_token,
|
||||
ref_view_strategy=ref_view_strategy,
|
||||
export_feat_layers=export_feat_layers,
|
||||
)
|
||||
head_out = self.head(feats, H=H, W=W, patch_start_idx=0)
|
||||
|
||||
# Pose prediction.
|
||||
out: Dict[str, torch.Tensor] = {}
|
||||
if use_ray_pose and "ray" in head_out and "ray_conf" in head_out:
|
||||
ray = head_out["ray"]
|
||||
ray_conf = head_out["ray_conf"]
|
||||
extr_c2w, focal, pp = get_extrinsic_from_camray(
|
||||
ray, ray_conf, ray.shape[-3], ray.shape[-2],
|
||||
)
|
||||
# Match the upstream output: w2c, drop the homogeneous row.
|
||||
extr_w2c = affine_inverse(extr_c2w)[:, :, :3, :]
|
||||
# Build pixel-space intrinsics from the normalised focal/pp output.
|
||||
intr = torch.eye(3, device=ray.device, dtype=ray.dtype)
|
||||
intr = intr[None, None].expand(extr_c2w.shape[0], extr_c2w.shape[1], 3, 3).clone()
|
||||
intr[:, :, 0, 0] = focal[:, :, 0] / 2 * W
|
||||
intr[:, :, 1, 1] = focal[:, :, 1] / 2 * H
|
||||
intr[:, :, 0, 2] = pp[:, :, 0] * W * 0.5
|
||||
intr[:, :, 1, 2] = pp[:, :, 1] * H * 0.5
|
||||
out["extrinsics"] = extr_w2c
|
||||
out["intrinsics"] = intr
|
||||
elif self.cam_dec is not None and S > 1:
|
||||
# Decode the cam-token of the final out_layer into a pose encoding.
|
||||
cam_feat = feats[-1][1] # (B, S, dim_in_to_cam_dec)
|
||||
pose_enc = self.cam_dec(cam_feat)
|
||||
c2w_3x4, intr = pose_encoding_to_extri_intri(pose_enc, (H, W))
|
||||
# Match the upstream output convention: w2c (world->camera), 3x4.
|
||||
c2w_4x4 = torch.cat([
|
||||
c2w_3x4,
|
||||
torch.tensor([0, 0, 0, 1], device=c2w_3x4.device, dtype=c2w_3x4.dtype)
|
||||
.view(1, 1, 1, 4).expand(B, S, 1, 4),
|
||||
], dim=-2)
|
||||
out["extrinsics"] = affine_inverse(c2w_4x4)[:, :, :3, :]
|
||||
out["intrinsics"] = intr
|
||||
|
||||
# Flatten the views axis for per-pixel outputs (depth/conf/sky) so the
|
||||
# per-image consumer keeps its (B*S, H, W) interface.
|
||||
for k, v in head_out.items():
|
||||
if k in ("ray", "ray_conf"):
|
||||
# Keep multi-view shape for downstream pose work.
|
||||
out[k] = v
|
||||
elif v.ndim >= 3 and v.shape[0] == B and v.shape[1] == S:
|
||||
out[k] = v.reshape(B * S, *v.shape[2:])
|
||||
else:
|
||||
out[k] = v
|
||||
|
||||
if export_feat_layers:
|
||||
out["aux_features"] = self._reshape_aux_features(aux_feats, H, W)
|
||||
return out
|
||||
|
||||
def _reshape_aux_features(self, aux_feats, H: int, W: int):
|
||||
"""Reshape (B, S, N, C) aux features into (B, S, h_p, w_p, C)."""
|
||||
ph, pw = H // self.PATCH_SIZE, W // self.PATCH_SIZE
|
||||
out = []
|
||||
for f in aux_feats:
|
||||
B, S, N, C = f.shape
|
||||
assert N == ph * pw, f"aux feature seq mismatch: {N} != {ph}*{pw}"
|
||||
out.append(f.reshape(B, S, ph, pw, C))
|
||||
return out
|
||||
128
comfy/ldm/depth_anything_3/preprocess.py
Normal file
128
comfy/ldm/depth_anything_3/preprocess.py
Normal file
@ -0,0 +1,128 @@
|
||||
"""Input/output preprocessing helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.utils
|
||||
|
||||
PATCH_SIZE = 14
|
||||
|
||||
# ImageNet normalization constants used during DA3 training.
|
||||
_IMAGENET_MEAN = torch.tensor([0.485, 0.456, 0.406])
|
||||
_IMAGENET_STD = torch.tensor([0.229, 0.224, 0.225])
|
||||
|
||||
|
||||
def _round_to_patch(x: int, patch: int = PATCH_SIZE) -> int:
|
||||
down = (x // patch) * patch
|
||||
up = down + patch
|
||||
return up if abs(up - x) <= abs(x - down) else down
|
||||
|
||||
|
||||
def compute_target_size(orig_h: int, orig_w: int, process_res: int, method: str = "upper_bound_resize") -> Tuple[int, int]:
|
||||
"""Compute (target_h, target_w) for a single image.
|
||||
upper_bound_resize: scale longest side to process_res, then round each dim to nearest multiple of 14 (default upstream method).
|
||||
lower_bound_resize: scale shortest side to process_res, then round."""
|
||||
|
||||
if method == "upper_bound_resize":
|
||||
longest = max(orig_h, orig_w)
|
||||
scale = process_res / float(longest)
|
||||
elif method == "lower_bound_resize":
|
||||
shortest = min(orig_h, orig_w)
|
||||
scale = process_res / float(shortest)
|
||||
else:
|
||||
raise ValueError(f"Unsupported process_res_method: {method}")
|
||||
|
||||
new_w = max(1, _round_to_patch(int(round(orig_w * scale))))
|
||||
new_h = max(1, _round_to_patch(int(round(orig_h * scale))))
|
||||
return new_h, new_w
|
||||
|
||||
|
||||
def preprocess_image(image: torch.Tensor, process_res: int = 504, method: str = "upper_bound_resize") -> torch.Tensor:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
B, H, W, _ = image.shape
|
||||
target_h, target_w = compute_target_size(H, W, process_res, method)
|
||||
|
||||
# (B, H, W, 3) -> (B, 3, H, W)
|
||||
x = image.movedim(-1, 1).contiguous()
|
||||
if (target_h, target_w) != (H, W):
|
||||
# Upstream uses cv2 INTER_CUBIC (upscale) / INTER_AREA (downscale).
|
||||
# Lanczos in ``common_upscale`` is anti-aliased and produces the
|
||||
# closest pixel-wise match in a sweep across {bilinear, bicubic,
|
||||
# area, lanczos, bislerp}. Used in both directions for simplicity.
|
||||
x = comfy.utils.common_upscale(x.float(), target_w, target_h, "lanczos", "disabled",)
|
||||
x = x.clamp(0.0, 1.0)
|
||||
|
||||
mean = _IMAGENET_MEAN.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
std = _IMAGENET_STD.to(device=x.device, dtype=x.dtype).view(1, 3, 1, 1)
|
||||
x = (x - mean) / std
|
||||
return x
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Output post-processing (sky-aware clipping for Mono/Metric variants)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def compute_non_sky_mask(sky_prediction: torch.Tensor, threshold: float = 0.3) -> torch.Tensor:
|
||||
"""Boolean mask: True for non-sky pixels (sky probability < threshold)."""
|
||||
return sky_prediction < threshold
|
||||
|
||||
|
||||
def apply_sky_aware_clip(depth: torch.Tensor, sky: torch.Tensor, threshold: float = 0.3, quantile: float = 0.99) -> torch.Tensor:
|
||||
"""Clips sky regions to the 99th percentile of non-sky depth. Returns a new depth tensor."""
|
||||
non_sky = compute_non_sky_mask(sky, threshold=threshold)
|
||||
if non_sky.sum() <= 10 or (~non_sky).sum() <= 10:
|
||||
return depth.clone()
|
||||
|
||||
non_sky_depth = depth[non_sky]
|
||||
if non_sky_depth.numel() > 100_000:
|
||||
idx = torch.randint(0, non_sky_depth.numel(), (100_000,), device=non_sky_depth.device)
|
||||
sampled = non_sky_depth[idx]
|
||||
else:
|
||||
sampled = non_sky_depth
|
||||
|
||||
max_depth = torch.quantile(sampled, quantile)
|
||||
out = depth.clone()
|
||||
out[~non_sky] = max_depth
|
||||
return out
|
||||
|
||||
|
||||
def normalize_depth_v2_style(depth: torch.Tensor, sky: torch.Tensor | None = None, low_quantile: float = 0.01, high_quantile: float = 0.99) -> torch.Tensor:
|
||||
"""V2-style normalization computes percentile bounds over non-sky pixels (when available), then maps depth into [0, 1] with near = white (1.0)."""
|
||||
if sky is not None:
|
||||
mask = compute_non_sky_mask(sky)
|
||||
if mask.any():
|
||||
valid = depth[mask]
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
else:
|
||||
valid = depth.flatten()
|
||||
|
||||
if valid.numel() > 100_000:
|
||||
idx = torch.randint(0, valid.numel(), (100_000,), device=valid.device)
|
||||
sample = valid[idx]
|
||||
else:
|
||||
sample = valid
|
||||
|
||||
lo = torch.quantile(sample, low_quantile)
|
||||
hi = torch.quantile(sample, high_quantile)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
norm = ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
# Nearer pixels are brighter (1.0)
|
||||
norm = 1.0 - norm
|
||||
if sky is not None:
|
||||
# Sky pixels become black (far / unknown)
|
||||
sky_mask = ~compute_non_sky_mask(sky)
|
||||
norm = torch.where(sky_mask, torch.zeros_like(norm), norm)
|
||||
return norm
|
||||
|
||||
|
||||
def normalize_depth_min_max(depth: torch.Tensor) -> torch.Tensor:
|
||||
"""Simple per-frame min/max normalization with near=1.0 convention."""
|
||||
lo = depth.amin(dim=(-2, -1), keepdim=True)
|
||||
hi = depth.amax(dim=(-2, -1), keepdim=True)
|
||||
rng = (hi - lo).clamp(min=1e-6)
|
||||
return 1.0 - ((depth - lo) / rng).clamp(0.0, 1.0)
|
||||
272
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
272
comfy/ldm/depth_anything_3/ray_pose.py
Normal file
@ -0,0 +1,272 @@
|
||||
"""Ray-to-pose conversion for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Optional, Tuple
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
# qr/svd use fp32: CUDA often has no fp16/bf16 kernels for these ops.
|
||||
|
||||
|
||||
def _ql_decomposition(A: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Decompose A = Q @ L with Q orthogonal and L lower-triangular.
|
||||
Implemented in terms of QR by reversing the columns/rows; the standard
|
||||
trick from the upstream reference. Inputs A are (3, 3)."""
|
||||
P = torch.tensor([[0, 0, 1], [0, 1, 0], [1, 0, 0]], device=A.device, dtype=A.dtype)
|
||||
A_tilde = A @ P
|
||||
# CUDA QR is not implemented for fp16/bf16; upcast just for this call.
|
||||
Q_tilde, R_tilde = torch.linalg.qr(A_tilde.float())
|
||||
Q_tilde = Q_tilde.to(A.dtype)
|
||||
R_tilde = R_tilde.to(A.dtype)
|
||||
Q = Q_tilde @ P
|
||||
L = P @ R_tilde @ P
|
||||
d = torch.diag(L)
|
||||
sign = torch.sign(d)
|
||||
Q = Q * sign[None, :] # scale columns of Q
|
||||
L = L * sign[:, None] # scale rows of L
|
||||
return Q, L
|
||||
|
||||
|
||||
def _homogenize_points(points: torch.Tensor) -> torch.Tensor:
|
||||
return torch.cat([points, torch.ones_like(points[..., :1])], dim=-1)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Weighted-LSQ + RANSAC homography (batched)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq(src_pts: torch.Tensor, dst_pts: torch.Tensor, confident_weight: torch.Tensor,) -> torch.Tensor:
|
||||
"""Solve a single H with weighted least-squares (DLT)."""
|
||||
N = src_pts.shape[0]
|
||||
if N < 4:
|
||||
raise ValueError("At least 4 points are required to compute a homography.")
|
||||
w = confident_weight.sqrt().unsqueeze(1) # (N, 1)
|
||||
x = src_pts[:, 0:1]
|
||||
y = src_pts[:, 1:2]
|
||||
u = dst_pts[:, 0:1]
|
||||
v = dst_pts[:, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=1)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=1)
|
||||
A = torch.cat([A1, A2], dim=0) # (2N, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[-1].reshape(3, 3)
|
||||
return H / H[-1, -1]
|
||||
|
||||
|
||||
def _find_homography_weighted_lsq_batched(src_pts_batch: torch.Tensor, dst_pts_batch: torch.Tensor, confident_weight_batch: torch.Tensor) -> torch.Tensor:
|
||||
"""Batched DLT solver. Inputs (B, K, 2) / (B, K); output (B, 3, 3)."""
|
||||
B, K, _ = src_pts_batch.shape
|
||||
w = confident_weight_batch.sqrt().unsqueeze(2)
|
||||
x = src_pts_batch[:, :, 0:1]
|
||||
y = src_pts_batch[:, :, 1:2]
|
||||
u = dst_pts_batch[:, :, 0:1]
|
||||
v = dst_pts_batch[:, :, 1:2]
|
||||
zeros = torch.zeros_like(x)
|
||||
A1 = torch.cat([-x * w, -y * w, -w, zeros, zeros, zeros, x * u * w, y * u * w, u * w], dim=2)
|
||||
A2 = torch.cat([zeros, zeros, zeros, -x * w, -y * w, -w, x * v * w, y * v * w, v * w], dim=2)
|
||||
A = torch.cat([A1, A2], dim=1) # (B, 2K, 9)
|
||||
# CUDA SVD is not implemented for fp16/bf16; upcast just for this call.
|
||||
_, _, Vh = torch.linalg.svd(A.float())
|
||||
Vh = Vh.to(A.dtype)
|
||||
H = Vh[:, -1].reshape(B, 3, 3)
|
||||
return H / H[:, 2:3, 2:3]
|
||||
|
||||
|
||||
def _ransac_find_homography_weighted_batched(
|
||||
src_pts: torch.Tensor, # (B, N, 2)
|
||||
dst_pts: torch.Tensor, # (B, N, 2)
|
||||
confident_weight: torch.Tensor, # (B, N)
|
||||
n_sample: int,
|
||||
n_iter: int = 100,
|
||||
reproj_threshold: float = 3.0,
|
||||
num_sample_for_ransac: int = 8,
|
||||
max_inlier_num: int = 10000,
|
||||
rand_sample_iters_idx: Optional[torch.Tensor] = None,
|
||||
) -> torch.Tensor:
|
||||
"""Batched weighted-RANSAC homography estimator. Returns (B, 3, 3) homography matrices."""
|
||||
B, N, _ = src_pts.shape
|
||||
assert N >= 4
|
||||
device = src_pts.device
|
||||
|
||||
sorted_idx = torch.argsort(confident_weight, descending=True, dim=1)
|
||||
candidate_idx = sorted_idx[:, :n_sample] # (B, n_sample)
|
||||
|
||||
if rand_sample_iters_idx is None:
|
||||
rand_sample_iters_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac]
|
||||
for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
rand_idx = candidate_idx[:, rand_sample_iters_idx] # (B, n_iter, k)
|
||||
b_idx = (
|
||||
torch.arange(B, device=device)
|
||||
.view(B, 1, 1)
|
||||
.expand(B, n_iter, num_sample_for_ransac)
|
||||
)
|
||||
src_b = src_pts[b_idx, rand_idx]
|
||||
dst_b = dst_pts[b_idx, rand_idx]
|
||||
w_b = confident_weight[b_idx, rand_idx]
|
||||
|
||||
cB, cN = src_b.shape[:2]
|
||||
H_batch = _find_homography_weighted_lsq_batched(
|
||||
src_b.flatten(0, 1), dst_b.flatten(0, 1), w_b.flatten(0, 1),
|
||||
).unflatten(0, (cB, cN)) # (B, n_iter, 3, 3)
|
||||
|
||||
src_homo = torch.cat([src_pts, torch.ones(B, N, 1, device=device, dtype=src_pts.dtype)], dim=2)
|
||||
proj = torch.bmm(
|
||||
src_homo.unsqueeze(1).expand(B, n_iter, N, 3).reshape(-1, N, 3),
|
||||
H_batch.reshape(-1, 3, 3).transpose(1, 2),
|
||||
) # (B*n_iter, N, 3)
|
||||
proj_xy = (proj[:, :, :2] / proj[:, :, 2:3]).reshape(B, n_iter, N, 2)
|
||||
err = ((proj_xy - dst_pts.unsqueeze(1)) ** 2).sum(-1).sqrt() # (B, n_iter, N)
|
||||
inlier_mask = err < reproj_threshold
|
||||
score = (inlier_mask * confident_weight.unsqueeze(1)).sum(dim=2)
|
||||
best_idx = torch.argmax(score, dim=1)
|
||||
best_inlier_mask = inlier_mask[torch.arange(B, device=device), best_idx]
|
||||
|
||||
# Refit with the inlier set (per-batch, since the inlier counts vary).
|
||||
H_inlier_list = []
|
||||
for b in range(B):
|
||||
mask = best_inlier_mask[b]
|
||||
in_src = src_pts[b][mask]
|
||||
in_dst = dst_pts[b][mask]
|
||||
in_w = confident_weight[b][mask]
|
||||
if in_src.shape[0] < 4:
|
||||
# Fall back to identity when RANSAC fails to find enough inliers.
|
||||
H_inlier_list.append(torch.eye(3, device=device, dtype=src_pts.dtype))
|
||||
continue
|
||||
sorted_w = torch.argsort(in_w, descending=True)
|
||||
if len(sorted_w) > max_inlier_num:
|
||||
keep = max(int(len(sorted_w) * 0.95), max_inlier_num)
|
||||
sorted_w = sorted_w[:keep][torch.randperm(keep, device=device)[:max_inlier_num]]
|
||||
H_inlier_list.append(
|
||||
_find_homography_weighted_lsq(in_src[sorted_w], in_dst[sorted_w], in_w[sorted_w])
|
||||
)
|
||||
return torch.stack(H_inlier_list, dim=0)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Camera-ray utilities
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _unproject_identity(num_y: int, num_x: int, B: int, S: int, device, dtype) -> torch.Tensor:
|
||||
"""Camera-space unit rays for an identity intrinsic on a 2x2 image plane."""
|
||||
dx = 1.0 / num_x
|
||||
dy = 1.0 / num_y
|
||||
# Centered camera-space coords directly (skip the K^-1 step since it's
|
||||
# just a translation by -1 on x and y when K is identity-with-center=1).
|
||||
y = torch.linspace(-(1 - dy), (1 - dy), num_y, device=device, dtype=dtype)
|
||||
x = torch.linspace(-(1 - dx), (1 - dx), num_x, device=device, dtype=dtype)
|
||||
yy, xx = torch.meshgrid(y, x, indexing="ij")
|
||||
grid = torch.stack((xx, yy), dim=-1) # (h, w, 2)
|
||||
grid = grid.unsqueeze(0).unsqueeze(0).expand(B, S, num_y, num_x, 2)
|
||||
return torch.cat([grid, torch.ones_like(grid[..., :1])], dim=-1)
|
||||
|
||||
|
||||
def _camray_to_caminfo(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
confidence: Optional[torch.Tensor] = None, # (B, S, h, w)
|
||||
reproj_threshold: float = 0.2,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Convert per-pixel camera rays to per-view (R, T, focal, principal)."""
|
||||
if confidence is None:
|
||||
confidence = torch.ones_like(camray[..., 0])
|
||||
B, S, h, w, _ = camray.shape
|
||||
device = camray.device
|
||||
dtype = camray.dtype
|
||||
|
||||
rays_target = camray[..., :3] # (B, S, h, w, 3)
|
||||
rays_origin = _unproject_identity(h, w, B, S, device, dtype)
|
||||
|
||||
# Flatten (B*S, h*w, *) for the RANSAC routine.
|
||||
rays_target = rays_target.flatten(0, 1).flatten(1, 2)
|
||||
rays_origin = rays_origin.flatten(0, 1).flatten(1, 2)
|
||||
weights = confidence.flatten(0, 1).flatten(1, 2).clone()
|
||||
|
||||
# Project to 2D in homogeneous form (the upstream calls this "perspective division").
|
||||
z_thresh = 1e-4
|
||||
mask = (rays_target[:, :, 2].abs() > z_thresh) & (rays_origin[:, :, 2].abs() > z_thresh)
|
||||
weights = torch.where(mask, weights, torch.zeros_like(weights))
|
||||
src = rays_origin.clone()
|
||||
dst = rays_target.clone()
|
||||
src[..., 0] = torch.where(mask, src[..., 0] / src[..., 2], src[..., 0])
|
||||
src[..., 1] = torch.where(mask, src[..., 1] / src[..., 2], src[..., 1])
|
||||
dst[..., 0] = torch.where(mask, dst[..., 0] / dst[..., 2], dst[..., 0])
|
||||
dst[..., 1] = torch.where(mask, dst[..., 1] / dst[..., 2], dst[..., 1])
|
||||
src = src[..., :2]
|
||||
dst = dst[..., :2]
|
||||
|
||||
N = src.shape[1]
|
||||
n_iter = 100
|
||||
sample_ratio = 0.3
|
||||
num_sample_for_ransac = 8
|
||||
n_sample = max(num_sample_for_ransac, int(N * sample_ratio))
|
||||
rand_idx = torch.stack(
|
||||
[torch.randperm(n_sample, device=device)[:num_sample_for_ransac] for _ in range(n_iter)],
|
||||
dim=0,
|
||||
)
|
||||
|
||||
# Chunk along the view axis to keep peak memory predictable.
|
||||
chunk = 2
|
||||
A_list = []
|
||||
for i in range(0, src.shape[0], chunk):
|
||||
A = _ransac_find_homography_weighted_batched(
|
||||
src[i:i + chunk], dst[i:i + chunk], weights[i:i + chunk],
|
||||
n_sample=n_sample, n_iter=n_iter,
|
||||
num_sample_for_ransac=num_sample_for_ransac,
|
||||
reproj_threshold=reproj_threshold,
|
||||
rand_sample_iters_idx=rand_idx,
|
||||
max_inlier_num=8000,
|
||||
)
|
||||
# Flip sign on dets that come out < 0 (so that the QL produces a
|
||||
# right-handed rotation). ``det`` lacks fp16/bf16 CUDA kernels, so
|
||||
# do the comparison in fp32.
|
||||
flip = torch.linalg.det(A.float()) < 0
|
||||
A = torch.where(flip[:, None, None], -A, A)
|
||||
A_list.append(A)
|
||||
A = torch.cat(A_list, dim=0) # (B*S, 3, 3)
|
||||
|
||||
R_list, f_list, pp_list = [], [], []
|
||||
for i in range(A.shape[0]):
|
||||
R, L = _ql_decomposition(A[i])
|
||||
L = L / L[2][2]
|
||||
f_list.append(torch.stack((L[0][0], L[1][1])))
|
||||
pp_list.append(torch.stack((L[2][0], L[2][1])))
|
||||
R_list.append(R)
|
||||
R = torch.stack(R_list).reshape(B, S, 3, 3)
|
||||
focal = torch.stack(f_list).reshape(B, S, 2)
|
||||
pp = torch.stack(pp_list).reshape(B, S, 2)
|
||||
|
||||
# Translation: confidence-weighted average of camray direction(s).
|
||||
cf = confidence.flatten(0, 1).flatten(1, 2)
|
||||
T = (camray.flatten(0, 1).flatten(1, 2)[..., 3:] * cf.unsqueeze(-1)).sum(dim=1)
|
||||
T = T / cf.sum(dim=-1, keepdim=True)
|
||||
T = T.reshape(B, S, 3)
|
||||
|
||||
# Match upstream output convention: focal -> 1/focal, pp + 1.
|
||||
return R, T, 1.0 / focal, pp + 1.0
|
||||
|
||||
|
||||
def get_extrinsic_from_camray(
|
||||
camray: torch.Tensor, # (B, S, h, w, 6)
|
||||
conf: torch.Tensor, # (B, S, h, w, 1) or (B, S, h, w)
|
||||
patch_size_y: int,
|
||||
patch_size_x: int,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""Wrap a 4x4 extrinsic + per-view focal + principal-point output."""
|
||||
if conf.ndim == 5 and conf.shape[-1] == 1:
|
||||
conf = conf.squeeze(-1)
|
||||
R, T, focal, pp = _camray_to_caminfo(camray, confidence=conf)
|
||||
extr = torch.cat([R, T.unsqueeze(-1)], dim=-1) # (B, S, 3, 4)
|
||||
homo_row = torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device)
|
||||
homo_row = homo_row.view(1, 1, 1, 4).expand(R.shape[0], R.shape[1], 1, 4)
|
||||
extr = torch.cat([extr, homo_row], dim=-2) # (B, S, 4, 4)
|
||||
return extr, focal, pp
|
||||
87
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
87
comfy/ldm/depth_anything_3/reference_view_selector.py
Normal file
@ -0,0 +1,87 @@
|
||||
"""Reference-view selection for the multi-view path of Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Literal
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
RefViewStrategy = Literal["first", "middle", "saddle_balanced", "saddle_sim_range"]
|
||||
|
||||
|
||||
# Per the upstream constants module: ``THRESH_FOR_REF_SELECTION = 3``.
|
||||
# Reference selection only runs when there are at least this many views.
|
||||
THRESH_FOR_REF_SELECTION: int = 3
|
||||
|
||||
|
||||
def select_reference_view(x: torch.Tensor, strategy: RefViewStrategy = "saddle_balanced") -> torch.Tensor:
|
||||
"""Pick a reference view index per batch element."""
|
||||
B, S, _, _ = x.shape
|
||||
if S <= 1:
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "first":
|
||||
return torch.zeros(B, dtype=torch.long, device=x.device)
|
||||
if strategy == "middle":
|
||||
return torch.full((B,), S // 2, dtype=torch.long, device=x.device)
|
||||
|
||||
# Feature-based strategies: normalised cls/cam token per view.
|
||||
img_class_feat = x[:, :, 0] / x[:, :, 0].norm(dim=-1, keepdim=True) # (B,S,C)
|
||||
|
||||
if strategy == "saddle_balanced":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2)) # (B,S,S)
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_score = sim_no_diag.sum(dim=-1) / (S - 1) # (B,S)
|
||||
feat_norm = x[:, :, 0].norm(dim=-1) # (B,S)
|
||||
feat_var = img_class_feat.var(dim=-1) # (B,S)
|
||||
|
||||
def _normalize(metric):
|
||||
mn = metric.min(dim=1, keepdim=True).values
|
||||
mx = metric.max(dim=1, keepdim=True).values
|
||||
return (metric - mn) / (mx - mn + 1e-8)
|
||||
|
||||
sim_n, norm_n, var_n = _normalize(sim_score), _normalize(feat_norm), _normalize(feat_var)
|
||||
balance = (sim_n - 0.5).abs() + (norm_n - 0.5).abs() + (var_n - 0.5).abs()
|
||||
return balance.argmin(dim=1)
|
||||
|
||||
if strategy == "saddle_sim_range":
|
||||
sim = torch.matmul(img_class_feat, img_class_feat.transpose(1, 2))
|
||||
sim_no_diag = sim - torch.eye(S, device=sim.device).unsqueeze(0)
|
||||
sim_max = sim_no_diag.max(dim=-1).values
|
||||
sim_min = sim_no_diag.min(dim=-1).values
|
||||
return (sim_max - sim_min).argmax(dim=1)
|
||||
|
||||
raise ValueError(
|
||||
f"Unknown reference view selection strategy: {strategy!r}. "
|
||||
f"Must be one of: 'first', 'middle', 'saddle_balanced', 'saddle_sim_range'"
|
||||
)
|
||||
|
||||
|
||||
def reorder_by_reference(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Reorder x so the reference view is at position 0 in axis S."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
reorder = torch.where(
|
||||
(positions > 0) & (positions <= b_idx_exp),
|
||||
positions - 1,
|
||||
positions,
|
||||
)
|
||||
reorder[:, 0] = b_idx
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, reorder]
|
||||
|
||||
|
||||
def restore_original_order(x: torch.Tensor, b_idx: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of reorder_by_reference."""
|
||||
B, S = x.shape[0], x.shape[1]
|
||||
if S <= 1:
|
||||
return x
|
||||
target_positions = torch.arange(S, device=x.device).unsqueeze(0).expand(B, -1)
|
||||
b_idx_exp = b_idx.unsqueeze(1)
|
||||
restore = torch.where(target_positions < b_idx_exp, target_positions + 1, target_positions)
|
||||
restore = torch.scatter(restore, dim=1, index=b_idx_exp, src=torch.zeros_like(b_idx_exp))
|
||||
batch = torch.arange(B, device=x.device).unsqueeze(1)
|
||||
return x[batch, restore]
|
||||
160
comfy/ldm/depth_anything_3/transform.py
Normal file
160
comfy/ldm/depth_anything_3/transform.py
Normal file
@ -0,0 +1,160 @@
|
||||
"""Geometry / camera transform helpers for Depth Anything 3."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Tuple
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Affine 4x4 helpers
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def as_homogeneous(ext: torch.Tensor) -> torch.Tensor:
|
||||
"""Promote (...,3,4) extrinsics to (...,4,4) homogeneous form. No-op when the input is already ``(...,4,4)``."""
|
||||
if ext.shape[-2:] == (4, 4):
|
||||
return ext
|
||||
if ext.shape[-2:] == (3, 4):
|
||||
ones = torch.zeros_like(ext[..., :1, :4])
|
||||
ones[..., 0, 3] = 1.0
|
||||
return torch.cat([ext, ones], dim=-2)
|
||||
raise ValueError(f"Invalid affine shape: {ext.shape}")
|
||||
|
||||
|
||||
def affine_inverse(A: torch.Tensor) -> torch.Tensor:
|
||||
"""Inverse of an affine matrix ``[R|T; 0 0 0 1]``."""
|
||||
R = A[..., :3, :3]
|
||||
T = A[..., :3, 3:]
|
||||
P = A[..., 3:, :]
|
||||
return torch.cat([torch.cat([R.mT, -R.mT @ T], dim=-1), P], dim=-2)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Quaternion <-> rotation matrix (xyzw / scalar-last)
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _sqrt_positive_part(x: torch.Tensor) -> torch.Tensor:
|
||||
"""sqrt(max(0, x)) with a zero subgradient where x == 0."""
|
||||
ret = torch.zeros_like(x)
|
||||
positive_mask = x > 0
|
||||
if torch.is_grad_enabled():
|
||||
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
||||
else:
|
||||
ret = torch.where(positive_mask, torch.sqrt(x), ret)
|
||||
return ret
|
||||
|
||||
|
||||
def standardize_quaternion(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Force the real part of a unit quaternion (xyzw) to be non-negative."""
|
||||
return torch.where(quaternions[..., 3:4] < 0, -quaternions, quaternions)
|
||||
|
||||
|
||||
def quat_to_mat(quaternions: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert quaternions (xyzw) to (...,3,3) rotation matrices."""
|
||||
i, j, k, r = torch.unbind(quaternions, -1)
|
||||
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
||||
o = torch.stack(
|
||||
(
|
||||
1 - two_s * (j * j + k * k),
|
||||
two_s * (i * j - k * r),
|
||||
two_s * (i * k + j * r),
|
||||
two_s * (i * j + k * r),
|
||||
1 - two_s * (i * i + k * k),
|
||||
two_s * (j * k - i * r),
|
||||
two_s * (i * k - j * r),
|
||||
two_s * (j * k + i * r),
|
||||
1 - two_s * (i * i + j * j),
|
||||
),
|
||||
-1,
|
||||
)
|
||||
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
||||
|
||||
|
||||
def mat_to_quat(matrix: torch.Tensor) -> torch.Tensor:
|
||||
"""Convert (...,3,3) rotation matrices to quaternions (xyzw)."""
|
||||
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
||||
raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.")
|
||||
|
||||
batch_dim = matrix.shape[:-2]
|
||||
m00, m01, m02, m10, m11, m12, m20, m21, m22 = torch.unbind(
|
||||
matrix.reshape(batch_dim + (9,)), dim=-1
|
||||
)
|
||||
|
||||
q_abs = _sqrt_positive_part(
|
||||
torch.stack(
|
||||
[
|
||||
1.0 + m00 + m11 + m22,
|
||||
1.0 + m00 - m11 - m22,
|
||||
1.0 - m00 + m11 - m22,
|
||||
1.0 - m00 - m11 + m22,
|
||||
],
|
||||
dim=-1,
|
||||
)
|
||||
)
|
||||
|
||||
quat_by_rijk = torch.stack(
|
||||
[
|
||||
torch.stack([q_abs[..., 0] ** 2, m21 - m12, m02 - m20, m10 - m01], dim=-1),
|
||||
torch.stack([m21 - m12, q_abs[..., 1] ** 2, m10 + m01, m02 + m20], dim=-1),
|
||||
torch.stack([m02 - m20, m10 + m01, q_abs[..., 2] ** 2, m12 + m21], dim=-1),
|
||||
torch.stack([m10 - m01, m20 + m02, m21 + m12, q_abs[..., 3] ** 2], dim=-1),
|
||||
],
|
||||
dim=-2,
|
||||
)
|
||||
|
||||
flr = torch.tensor(0.1).to(dtype=q_abs.dtype, device=q_abs.device)
|
||||
quat_candidates = quat_by_rijk / (2.0 * q_abs[..., None].max(flr))
|
||||
|
||||
out = quat_candidates[F.one_hot(q_abs.argmax(dim=-1), num_classes=4) > 0.5, :].reshape(
|
||||
batch_dim + (4,)
|
||||
)
|
||||
# Reorder rijk -> xyzw (i.e. ijkr).
|
||||
out = out[..., [1, 2, 3, 0]]
|
||||
return standardize_quaternion(out)
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Pose-encoding <-> extrinsics + intrinsics
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
def extri_intri_to_pose_encoding(extrinsics: torch.Tensor, intrinsics: torch.Tensor, image_size_hw: Tuple[int, int]) -> torch.Tensor:
|
||||
"""Pack (extr, intr, image_size) into the 9-D pose-encoding vector.
|
||||
extrinsics: camera-to-world (c2w) (B,S,4,4) matrices,
|
||||
intrinsics: pixel-space (B,S,3,3) matrices,
|
||||
image_size_hw: is a (H, W) pair.
|
||||
"""
|
||||
R = extrinsics[..., :3, :3]
|
||||
T = extrinsics[..., :3, 3]
|
||||
quat = mat_to_quat(R)
|
||||
H, W = image_size_hw
|
||||
fov_h = 2 * torch.atan((H / 2) / intrinsics[..., 1, 1])
|
||||
fov_w = 2 * torch.atan((W / 2) / intrinsics[..., 0, 0])
|
||||
return torch.cat([T, quat, fov_h[..., None], fov_w[..., None]], dim=-1).float()
|
||||
|
||||
|
||||
def pose_encoding_to_extri_intri(pose_encoding: torch.Tensor, image_size_hw: Tuple[int, int]) -> Tuple[torch.Tensor, torch.Tensor]:
|
||||
"""Inverse of extri_intri_to_pose_encoding."""
|
||||
T = pose_encoding[..., :3]
|
||||
quat = pose_encoding[..., 3:7]
|
||||
fov_h = pose_encoding[..., 7]
|
||||
fov_w = pose_encoding[..., 8]
|
||||
# Normalize to unit quaternion. CameraDec outputs raw values; a near-zero
|
||||
# quaternion causes two_s = 2/norm² → inf in quat_to_mat → NaN extrinsics.
|
||||
quat = quat / quat.norm(dim=-1, keepdim=True).clamp(min=1e-6)
|
||||
R = quat_to_mat(quat)
|
||||
extrinsics = torch.cat([R, T[..., None]], dim=-1)
|
||||
H, W = image_size_hw
|
||||
fy = (H / 2.0) / torch.clamp(torch.tan(fov_h / 2.0), 1e-6)
|
||||
fx = (W / 2.0) / torch.clamp(torch.tan(fov_w / 2.0), 1e-6)
|
||||
intrinsics = torch.zeros(pose_encoding.shape[:2] + (3, 3), device=pose_encoding.device, dtype=pose_encoding.dtype)
|
||||
intrinsics[..., 0, 0] = fx
|
||||
intrinsics[..., 1, 1] = fy
|
||||
intrinsics[..., 0, 2] = W / 2
|
||||
intrinsics[..., 1, 2] = H / 2
|
||||
intrinsics[..., 2, 2] = 1.0
|
||||
return extrinsics, intrinsics
|
||||
@ -106,11 +106,11 @@ class Ideogram4EmbedScalar(nn.Module):
|
||||
self.mlp_in = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
self.mlp_out = operations.Linear(dim, dim, bias=True, dtype=dtype, device=device)
|
||||
|
||||
def forward(self, x):
|
||||
def forward(self, x, dtype):
|
||||
x = x.to(torch.float32)
|
||||
scaled = 1e4 * (x - self.range_min) / (self.range_max - self.range_min)
|
||||
emb = _sinusoidal_embedding(scaled, self.dim)
|
||||
emb = emb.to(self.mlp_in.weight.dtype)
|
||||
emb = emb.to(dtype)
|
||||
emb = F.silu(self.mlp_in(emb))
|
||||
return self.mlp_out(emb)
|
||||
|
||||
@ -161,7 +161,7 @@ class Ideogram4Transformer(nn.Module):
|
||||
x = x * output_image_mask
|
||||
h = self.input_proj(x) * output_image_mask
|
||||
|
||||
t_cond = self.t_embedding(t)
|
||||
t_cond = self.t_embedding(t, dtype=x.dtype)
|
||||
if t.dim() == 1:
|
||||
t_cond = t_cond.unsqueeze(1)
|
||||
adaln_input = F.silu(self.adaln_proj(t_cond))
|
||||
|
||||
@ -8,6 +8,7 @@ import torch.nn.functional as F
|
||||
from einops import rearrange, repeat
|
||||
from comfy.ldm.lightricks.model import Timesteps
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.modules.attention import optimized_attention_masked
|
||||
import comfy.model_management
|
||||
import comfy.ldm.common_dit
|
||||
@ -17,9 +18,7 @@ def apply_rotary_emb(x, freqs_cis):
|
||||
if x.shape[1] == 0:
|
||||
return x
|
||||
|
||||
t_ = x.reshape(*x.shape[:-1], -1, 1, 2)
|
||||
t_out = freqs_cis[..., 0] * t_[..., 0] + freqs_cis[..., 1] * t_[..., 1]
|
||||
return t_out.reshape(*x.shape).to(dtype=x.dtype)
|
||||
return apply_rope1(x, freqs_cis)
|
||||
|
||||
|
||||
def swiglu(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
||||
|
||||
@ -8,7 +8,7 @@ from einops import rearrange
|
||||
|
||||
from comfy.ldm.modules.attention import optimized_attention
|
||||
from comfy.ldm.flux.layers import EmbedND
|
||||
from comfy.ldm.flux.math import apply_rope1
|
||||
from comfy.ldm.flux.math import apply_rope1, rope
|
||||
import comfy.ldm.common_dit
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -570,6 +570,14 @@ class WanModel(torch.nn.Module):
|
||||
full_ref = self.ref_conv(full_ref).flatten(2).transpose(1, 2)
|
||||
x = torch.concat((full_ref, x), dim=1)
|
||||
|
||||
# In-context reference (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
main_len = x.shape[1]
|
||||
if context_latents is not None:
|
||||
for lat in context_latents:
|
||||
cl = self.patch_embedding(lat.float().to(x.device)).to(x.dtype).flatten(2).transpose(1, 2)
|
||||
x = torch.cat([x, cl], dim=1)
|
||||
|
||||
# context
|
||||
context = self.text_embedding(context)
|
||||
|
||||
@ -599,6 +607,9 @@ class WanModel(torch.nn.Module):
|
||||
# head
|
||||
x = self.head(x, e)
|
||||
|
||||
if context_latents is not None:
|
||||
x = x[:, :main_len]
|
||||
|
||||
if full_ref is not None:
|
||||
x = x[:, full_ref.shape[1]:]
|
||||
|
||||
@ -606,7 +617,7 @@ class WanModel(torch.nn.Module):
|
||||
x = self.unpatchify(x, grid_sizes)
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}):
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, transformer_options={}, source_id=0):
|
||||
patch_size = self.patch_size
|
||||
t_len = ((t + (patch_size[0] // 2)) // patch_size[0])
|
||||
h_len = ((h + (patch_size[1] // 2)) // patch_size[1])
|
||||
@ -638,6 +649,13 @@ class WanModel(torch.nn.Module):
|
||||
img_ids = img_ids.reshape(1, -1, img_ids.shape[-1])
|
||||
|
||||
freqs = self.rope_embedder(img_ids).movedim(1, 2)
|
||||
|
||||
# In-context reference: a non-zero source_id composes an extra rotation into the spatial rope
|
||||
if source_id:
|
||||
d = self.dim // self.num_heads
|
||||
pos = torch.tensor([[float(source_id)]], device=freqs.device, dtype=torch.float32)
|
||||
id_rot = rope(pos, d, self.rope_embedder.theta).reshape(1, 1, 1, d // 2, 2, 2).to(freqs.dtype)
|
||||
freqs = torch.einsum('...ij,...jk->...ik', freqs, id_rot)
|
||||
return freqs
|
||||
|
||||
def forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, **kwargs):
|
||||
@ -661,6 +679,15 @@ class WanModel(torch.nn.Module):
|
||||
t_len += 1
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options)
|
||||
|
||||
# In-context reference: one rope block per stream, each with it's own source_id (1, 2, ...) to distinguish from the target (id 0).
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
context_latents = [comfy.ldm.common_dit.pad_to_patch_size(lat, self.patch_size) for lat in context_latents]
|
||||
for i, lat in enumerate(context_latents):
|
||||
freqs = torch.cat([freqs, self.rope_encode(lat.shape[-3], lat.shape[-2], lat.shape[-1], device=x.device, dtype=x.dtype, transformer_options=transformer_options, source_id=i + 1)], dim=1)
|
||||
kwargs = {**kwargs, "context_latents": context_latents}
|
||||
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
def unpatchify(self, x, grid_sizes):
|
||||
@ -1631,13 +1658,15 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
self.patch_embedding_pose = operations.Conv3d(in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, **kwargs):
|
||||
def forward_orig(self, x, t, context, clip_fea=None, freqs=None, transformer_options={}, pose_latents=None, reference_latent=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
|
||||
if reference_latent is not None:
|
||||
x = torch.cat((reference_latent, x), dim=2)
|
||||
|
||||
# embeddings
|
||||
x = self.patch_embedding(x.float()).to(x.dtype)
|
||||
if ref_mask_latents is not None: # SCAIL-2 additive mask stream
|
||||
x = x + self.patch_embedding_mask(ref_mask_latents.float()).to(x.dtype)
|
||||
grid_sizes = x.shape[2:]
|
||||
transformer_options["grid_sizes"] = grid_sizes
|
||||
x = x.flatten(2).transpose(1, 2)
|
||||
@ -1645,6 +1674,8 @@ class SCAILWanModel(WanModel):
|
||||
scail_pose_seq_len = 0
|
||||
if pose_latents is not None:
|
||||
scail_x = self.patch_embedding_pose(pose_latents.float()).to(x.dtype)
|
||||
if sam_latents is not None: # SCAIL-2 additive mask stream
|
||||
scail_x = scail_x + self.patch_embedding_mask(sam_latents.float()).to(x.dtype)
|
||||
scail_x = scail_x.flatten(2).transpose(1, 2)
|
||||
scail_pose_seq_len = scail_x.shape[1]
|
||||
x = torch.cat([x, scail_x], dim=1)
|
||||
@ -1695,7 +1726,36 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return x
|
||||
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, transformer_options={}):
|
||||
# ref_mask_flag is a scalar bool (CONDConstant, SCAIL-2 only). False => replacement mode,
|
||||
# which places ref/pose via H/W rope shifts instead of the animation-mode temporal offset.
|
||||
def rope_encode(self, t, h, w, t_start=0, steps_t=None, steps_h=None, steps_w=None, device=None, dtype=None, pose_latents=None, reference_latent=None, ref_mask_flag=None, transformer_options={}):
|
||||
if ref_mask_flag is not None and not bool(ref_mask_flag):
|
||||
REF_ROPE_H = 120.0
|
||||
POSE_ROPE_W = 120.0
|
||||
|
||||
ref_t_patches = 0
|
||||
if reference_latent is not None:
|
||||
ref_t_patches = (reference_latent.shape[2] + (self.patch_size[0] // 2)) // self.patch_size[0]
|
||||
main_t_patches = t - ref_t_patches
|
||||
|
||||
parts = []
|
||||
if ref_t_patches > 0:
|
||||
ref_tf = {"rope_options": {"shift_y": REF_ROPE_H, "shift_x": 0.0, "scale_y": 1.0, "scale_x": 1.0}}
|
||||
parts.append(super().rope_encode(ref_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=ref_tf))
|
||||
if main_t_patches > 0:
|
||||
parts.append(super().rope_encode(main_t_patches, h, w, t_start=0, device=device, dtype=dtype, transformer_options=transformer_options))
|
||||
|
||||
if pose_latents is not None:
|
||||
F_pose, H_pose, W_pose = pose_latents.shape[-3], pose_latents.shape[-2], pose_latents.shape[-1]
|
||||
h_scale = h / H_pose
|
||||
w_scale = w / W_pose
|
||||
h_shift = (h_scale - 1) / 2
|
||||
w_shift = (w_scale - 1) / 2
|
||||
pose_tf = {"rope_options": {"shift_y": h_shift, "shift_x": POSE_ROPE_W + w_shift, "scale_y": h_scale, "scale_x": w_scale}}
|
||||
parts.append(super().rope_encode(F_pose, H_pose, W_pose, t_start=0, device=device, dtype=dtype, transformer_options=pose_tf))
|
||||
|
||||
return torch.cat(parts, dim=1)
|
||||
|
||||
main_freqs = super().rope_encode(t, h, w, t_start=t_start, steps_t=steps_t, steps_h=steps_h, steps_w=steps_w, device=device, dtype=dtype, transformer_options=transformer_options)
|
||||
|
||||
if pose_latents is None:
|
||||
@ -1719,12 +1779,16 @@ class SCAILWanModel(WanModel):
|
||||
|
||||
return torch.cat([main_freqs, pose_freqs], dim=1)
|
||||
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, **kwargs):
|
||||
def _forward(self, x, timestep, context, clip_fea=None, time_dim_concat=None, transformer_options={}, pose_latents=None, ref_mask_latents=None, sam_latents=None, **kwargs):
|
||||
bs, c, t, h, w = x.shape
|
||||
x = comfy.ldm.common_dit.pad_to_patch_size(x, self.patch_size)
|
||||
|
||||
if pose_latents is not None:
|
||||
pose_latents = comfy.ldm.common_dit.pad_to_patch_size(pose_latents, self.patch_size)
|
||||
if ref_mask_latents is not None: # SCAIL-2
|
||||
ref_mask_latents = comfy.ldm.common_dit.pad_to_patch_size(ref_mask_latents, self.patch_size)
|
||||
if sam_latents is not None: # SCAIL-2
|
||||
sam_latents = comfy.ldm.common_dit.pad_to_patch_size(sam_latents, self.patch_size)
|
||||
|
||||
t_len = t
|
||||
if time_dim_concat is not None:
|
||||
@ -1737,5 +1801,15 @@ class SCAILWanModel(WanModel):
|
||||
reference_latent = comfy.ldm.common_dit.pad_to_patch_size(kwargs.pop("reference_latent"), self.patch_size)
|
||||
t_len += reference_latent.shape[2]
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, **kwargs)[:, :, :t, :h, :w]
|
||||
ref_mask_flag = kwargs.pop("ref_mask_flag", None) # SCAIL-2
|
||||
|
||||
freqs = self.rope_encode(t_len, h, w, device=x.device, dtype=x.dtype, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_flag=ref_mask_flag)
|
||||
return self.forward_orig(x, timestep, context, clip_fea=clip_fea, freqs=freqs, transformer_options=transformer_options, pose_latents=pose_latents, reference_latent=reference_latent, ref_mask_latents=ref_mask_latents, sam_latents=sam_latents, **kwargs)[:, :, :t, :h, :w]
|
||||
|
||||
|
||||
class SCAIL2WanModel(SCAILWanModel):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_type="scail2", patch_size=(1, 2, 2), in_dim=20, mask_in_dim=28, dim=5120, operations=None, device=None, dtype=None, **kwargs):
|
||||
super().__init__(model_type=model_type, patch_size=patch_size, in_dim=in_dim, dim=dim, operations=operations, device=device, dtype=dtype, **kwargs)
|
||||
self.patch_embedding_mask = operations.Conv3d(mask_in_dim, dim, kernel_size=patch_size, stride=patch_size, device=device, dtype=torch.float32)
|
||||
|
||||
@ -357,6 +357,12 @@ def model_lora_keys_unet(model, key_map={}):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["transformer.{}".format(key_lora)] = k
|
||||
|
||||
if isinstance(model, (comfy.model_base.LTXV, comfy.model_base.LTXAV)):
|
||||
for k in sdk:
|
||||
if k.startswith("diffusion_model.") and k.endswith(".weight"):
|
||||
key_lora = k[len("diffusion_model."):-len(".weight")]
|
||||
key_map["{}".format(key_lora)] = k
|
||||
|
||||
return key_map
|
||||
|
||||
|
||||
|
||||
@ -65,6 +65,7 @@ import comfy.ldm.ernie.model
|
||||
import comfy.ldm.sam3.detector
|
||||
import comfy.ldm.hidream_o1.model
|
||||
from comfy.ldm.hidream_o1.conditioning import build_extra_conds
|
||||
import comfy.ldm.depth_anything_3.model
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.patcher_extension
|
||||
@ -1518,8 +1519,26 @@ class WAN21(BaseModel):
|
||||
if reference_latents is not None:
|
||||
out['reference_latent'] = comfy.conds.CONDRegular(self.process_latent_in(reference_latents[-1])[:, :, 0])
|
||||
|
||||
# In-context reference conditioning (Bernini)
|
||||
context_latents = kwargs.get("context_latents", None)
|
||||
if context_latents is not None:
|
||||
out['context_latents'] = comfy.conds.CONDList([self.process_latent_in(l) for l in context_latents])
|
||||
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
# In-context cond slicing (Bernini)
|
||||
if cond_key == "context_latents" and isinstance(getattr(cond_value, "cond", None), list):
|
||||
dim = window.dim
|
||||
out = []
|
||||
for lat in cond_value.cond:
|
||||
if lat.ndim > dim and lat.shape[dim] > 1 and lat.shape[dim] == x_in.shape[dim]:
|
||||
out.append(window.get_tensor(lat, device, dim=dim, retain_index_list=retain_index_list))
|
||||
else:
|
||||
out.append(lat.to(device))
|
||||
return cond_value._copy_with(out)
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
|
||||
class WAN21_CausalAR(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
@ -1754,6 +1773,97 @@ class WAN21_SCAIL(WAN21):
|
||||
|
||||
return out
|
||||
|
||||
class WAN21_SCAIL2(WAN21_SCAIL):
|
||||
"""SCAIL-2: SCAIL-Preview + an additive binary multi-identity mask stream."""
|
||||
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=False, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model.SCAIL2WanModel)
|
||||
self.memory_usage_factor_conds = ("reference_latent", "pose_latents", "ref_mask_latents", "sam_latents")
|
||||
self.memory_usage_shape_process = {
|
||||
"pose_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
"sam_latents": lambda shape: [shape[0], shape[1], 1.5, shape[-2], shape[-1]],
|
||||
}
|
||||
self.image_to_video = image_to_video
|
||||
|
||||
def extra_conds(self, **kwargs):
|
||||
out = super().extra_conds(**kwargs)
|
||||
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
out['sam_latents'] = comfy.conds.CONDRegular(driving_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
out['ref_mask_latents'] = comfy.conds.CONDRegular(ref_mask_28ch.movedim(1, 2).contiguous())
|
||||
|
||||
ref_mask_flag = kwargs.get("ref_mask_flag", None)
|
||||
if ref_mask_flag is not None:
|
||||
out['ref_mask_flag'] = comfy.conds.CONDConstant(ref_mask_flag)
|
||||
|
||||
return out
|
||||
|
||||
def extra_conds_shapes(self, **kwargs):
|
||||
out = super().extra_conds_shapes(**kwargs)
|
||||
driving_mask_28ch = kwargs.get("driving_mask_28ch", None)
|
||||
if driving_mask_28ch is not None:
|
||||
s = driving_mask_28ch.shape
|
||||
out['sam_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
ref_mask_28ch = kwargs.get("ref_mask_28ch", None)
|
||||
if ref_mask_28ch is not None:
|
||||
s = ref_mask_28ch.shape
|
||||
out['ref_mask_latents'] = [s[0], 28, s[1], s[3], s[4]]
|
||||
return out
|
||||
|
||||
def resize_cond_for_context_window(self, cond_key, cond_value, window, x_in, device, retain_index_list=[]):
|
||||
if cond_key in ("sam_latents", "pose_latents"):
|
||||
# Return sliced view omitting retain_index_list
|
||||
return comfy.context_windows.slice_cond(cond_value, window, x_in, device, temporal_dim=2, temporal_offset=0)
|
||||
if cond_key == "ref_mask_latents" and hasattr(cond_value, "cond") and isinstance(cond_value.cond, torch.Tensor):
|
||||
# The ref mask is just a single frame padded with frames of zeros, so just grab the first frames for all windows
|
||||
full_ref_mask = cond_value.cond
|
||||
video_frame_count = x_in.shape[2]
|
||||
if full_ref_mask.shape[2] != video_frame_count + 1:
|
||||
return None
|
||||
window_length = len(window.index_list)
|
||||
|
||||
# Account for the causal anchor frame if it exists
|
||||
anchor_index = getattr(window, "causal_anchor_index", None)
|
||||
if anchor_index is not None and anchor_index >= 0:
|
||||
window_length += 1
|
||||
|
||||
window_ref_mask = full_ref_mask[:, :, :window_length + 1].to(device)
|
||||
return cond_value._copy_with(window_ref_mask)
|
||||
|
||||
return super().resize_cond_for_context_window(cond_key, cond_value, window, x_in, device, retain_index_list=retain_index_list)
|
||||
|
||||
def concat_cond(self, **kwargs):
|
||||
# The 4 extra channels are the history_mask (1 at clean-anchor frames).
|
||||
noise = kwargs.get("noise", None)
|
||||
extra_channels = self.diffusion_model.patch_embedding.weight.shape[1] - noise.shape[1]
|
||||
if extra_channels != 4:
|
||||
return super().concat_cond(**kwargs)
|
||||
|
||||
mask = kwargs.get("concat_mask", kwargs.get("denoise_mask", None))
|
||||
if mask is None:
|
||||
return torch.zeros_like(noise)[:, :4]
|
||||
|
||||
device = kwargs["device"]
|
||||
if mask.shape[1] != 4:
|
||||
mask = torch.mean(mask, dim=1, keepdim=True)
|
||||
mask = 1.0 - mask
|
||||
mask = utils.common_upscale(mask.to(device), noise.shape[-1], noise.shape[-2], "bilinear", "center")
|
||||
if mask.shape[-3] < noise.shape[-3]:
|
||||
mask = torch.nn.functional.pad(mask, (0, 0, 0, 0, 0, noise.shape[-3] - mask.shape[-3]), mode='constant', value=0)
|
||||
if mask.shape[1] == 1:
|
||||
mask = mask.repeat(1, 4, 1, 1, 1)
|
||||
mask = utils.resize_to_batch_size(mask, noise.shape[0])
|
||||
return mask
|
||||
|
||||
def scale_latent_inpaint(self, sigma, noise, latent_image, **kwargs):
|
||||
# Hold anchor constant across all sigmas instead of base sigma*noise + (1-sigma)*latent_image.
|
||||
return latent_image
|
||||
|
||||
|
||||
class WAN22_WanDancer(WAN21):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, image_to_video=True, device=None):
|
||||
super(WAN21, self).__init__(model_config, model_type, device=device, unet_model=comfy.ldm.wan.model_wandancer.WanDancerModel)
|
||||
@ -2227,6 +2337,12 @@ class RT_DETR_v4(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.rt_detr.rtdetr_v4.RTv4)
|
||||
|
||||
|
||||
class DepthAnything3(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device,
|
||||
unet_model=comfy.ldm.depth_anything_3.model.DepthAnything3Net)
|
||||
|
||||
class ErnieImage(BaseModel):
|
||||
def __init__(self, model_config, model_type=ModelType.FLOW, device=None):
|
||||
super().__init__(model_config, model_type, device=device, unet_model=comfy.ldm.ernie.model.ErnieImageModel)
|
||||
|
||||
@ -630,6 +630,8 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["model_type"] = "humo"
|
||||
elif '{}face_adapter.fuser_blocks.0.k_norm.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "animate"
|
||||
elif '{}patch_embedding_mask.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail2"
|
||||
elif '{}patch_embedding_pose.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config["model_type"] = "scail"
|
||||
elif '{}patch_embedding_global.weight'.format(key_prefix) in state_dict_keys:
|
||||
@ -860,6 +862,95 @@ def detect_unet_config(state_dict, key_prefix, metadata=None):
|
||||
dit_config["enc_h"] = state_dict['{}encoder.pan_blocks.1.cv4.conv.weight'.format(key_prefix)].shape[0]
|
||||
return dit_config
|
||||
|
||||
# Depth Anything 3 (repackaged to ComfyUI's native Dinov2Model layout via scripts/convert_da3.py)
|
||||
if '{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix) in state_dict_keys:
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "DepthAnything3"
|
||||
|
||||
patch_w = state_dict['{}backbone.embeddings.patch_embeddings.projection.weight'.format(key_prefix)]
|
||||
embed_dim = patch_w.shape[0]
|
||||
depth = count_blocks(state_dict_keys, '{}backbone.encoder.layer.'.format(key_prefix) + '{}.')
|
||||
|
||||
# Backbone preset is determined by embed_dim (matches vits/vitb/vitl/vitg).
|
||||
backbone_name = {384: "vits", 768: "vitb", 1024: "vitl", 1536: "vitg"}.get(embed_dim)
|
||||
if backbone_name is None:
|
||||
return None
|
||||
dit_config["backbone_name"] = backbone_name
|
||||
|
||||
# Detect DA3 extensions on top of vanilla DINOv2.
|
||||
has_camera_token = '{}backbone.embeddings.camera_token'.format(key_prefix) in state_dict_keys
|
||||
# qk-norm shows up as `attention.q_norm.weight` on enabled blocks.
|
||||
qknorm_indices = [
|
||||
i for i in range(depth)
|
||||
if '{}backbone.encoder.layer.{}.attention.q_norm.weight'.format(key_prefix, i) in state_dict_keys
|
||||
]
|
||||
qknorm_start = qknorm_indices[0] if qknorm_indices else -1
|
||||
|
||||
# The DA3 main-series configs always set alt_start == qknorm_start == rope_start.
|
||||
# cat_token=True is implied by the presence of camera_token.
|
||||
if has_camera_token:
|
||||
dit_config["alt_start"] = qknorm_start
|
||||
dit_config["rope_start"] = qknorm_start
|
||||
dit_config["qknorm_start"] = qknorm_start
|
||||
dit_config["cat_token"] = True
|
||||
else:
|
||||
dit_config["alt_start"] = -1
|
||||
dit_config["rope_start"] = -1
|
||||
dit_config["qknorm_start"] = -1
|
||||
dit_config["cat_token"] = False
|
||||
|
||||
# Detect head type and config.
|
||||
has_aux = '{}head.scratch.refinenet1_aux.out_conv.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["head_dim_in"] = state_dict['{}head.projects.0.weight'.format(key_prefix)].shape[1]
|
||||
dit_config["head_features"] = state_dict['{}head.scratch.refinenet1.out_conv.weight'.format(key_prefix)].shape[0]
|
||||
dit_config["head_out_channels"] = [
|
||||
state_dict['{}head.projects.{}.weight'.format(key_prefix, i)].shape[0]
|
||||
for i in range(4)
|
||||
]
|
||||
if has_aux:
|
||||
# DualDPT: dim_in = 2 * embed_dim (because cat_token doubles token width).
|
||||
dit_config["head_type"] = "dualdpt"
|
||||
dit_config["head_output_dim"] = 2
|
||||
dit_config["head_use_sky_head"] = False
|
||||
else:
|
||||
dit_config["head_type"] = "dpt"
|
||||
dit_config["head_output_dim"] = state_dict[
|
||||
'{}head.scratch.output_conv2.2.weight'.format(key_prefix)
|
||||
].shape[0]
|
||||
dit_config["head_use_sky_head"] = (
|
||||
'{}head.scratch.sky_output_conv2.0.weight'.format(key_prefix) in state_dict_keys
|
||||
)
|
||||
|
||||
# out_layers: hard-coded per upstream YAML config (depth-aware default).
|
||||
if depth >= 24:
|
||||
# vitl: depths used vary between DA3-Large (DualDPT) and Mono/Metric (DPT).
|
||||
if has_aux:
|
||||
dit_config["out_layers"] = [11, 15, 19, 23]
|
||||
else:
|
||||
dit_config["out_layers"] = [4, 11, 17, 23]
|
||||
else:
|
||||
# vits/vitb: 12 blocks
|
||||
dit_config["out_layers"] = [5, 7, 9, 11]
|
||||
|
||||
# Camera encoder/decoder presence (multi-view + pose path).
|
||||
has_cam_enc = '{}cam_enc.token_norm.weight'.format(key_prefix) in state_dict_keys
|
||||
has_cam_dec = '{}cam_dec.fc_t.weight'.format(key_prefix) in state_dict_keys
|
||||
dit_config["has_cam_enc"] = has_cam_enc
|
||||
dit_config["has_cam_dec"] = has_cam_dec
|
||||
if has_cam_enc:
|
||||
cam_enc_w = state_dict.get(
|
||||
'{}cam_enc.pose_branch.fc2.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_enc_w is not None:
|
||||
dit_config["cam_dim_out"] = cam_enc_w.shape[0]
|
||||
if has_cam_dec:
|
||||
cam_dec_w = state_dict.get(
|
||||
'{}cam_dec.fc_t.weight'.format(key_prefix)
|
||||
)
|
||||
if cam_dec_w is not None:
|
||||
dit_config["cam_dec_dim_in"] = cam_dec_w.shape[1]
|
||||
return dit_config
|
||||
|
||||
if '{}layers.0.mlp.linear_fc2.weight'.format(key_prefix) in state_dict_keys: # Ernie Image
|
||||
dit_config = {}
|
||||
dit_config["image_model"] = "ernie"
|
||||
|
||||
@ -534,8 +534,10 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available() and PerformanceFeature.AutoTune in args.fast:
|
||||
torch.backends.cudnn.benchmark = True
|
||||
|
||||
def set_cudnn_benchmark():
|
||||
if torch.cuda.is_available() and torch.backends.cudnn.is_available():
|
||||
torch.backends.cudnn.benchmark = PerformanceFeature.AutoTune in args.fast
|
||||
|
||||
try:
|
||||
if torch_version_numeric >= (2, 5):
|
||||
@ -641,6 +643,8 @@ def free_pins(size, evict_active=False):
|
||||
return freed_total
|
||||
|
||||
def ensure_pin_budget(size, evict_active=False):
|
||||
if args.high_ram:
|
||||
return True
|
||||
if args.fast_disk:
|
||||
shortfall = TOTAL_PINNED_MEMORY + size - MAX_PINNED_MEMORY
|
||||
else:
|
||||
@ -958,8 +962,6 @@ def loaded_models(only_currently_used=False):
|
||||
def cleanup_models_gc():
|
||||
do_gc = False
|
||||
|
||||
reset_cast_buffers()
|
||||
|
||||
for i in range(len(current_loaded_models)):
|
||||
cur = current_loaded_models[i]
|
||||
if cur.is_dead():
|
||||
@ -1496,6 +1498,8 @@ if not args.disable_pinned_memory:
|
||||
PINNING_ALLOWED_TYPES = set(["Tensor", "Parameter", "QuantizedTensor"])
|
||||
|
||||
def pinned_hostbuf_size(size):
|
||||
if args.high_ram:
|
||||
return max(0, int(size * 2))
|
||||
return max(0, int(min(size, MAX_PINNED_MEMORY) * 2))
|
||||
|
||||
def discard_cuda_async_error():
|
||||
|
||||
@ -379,10 +379,11 @@ class ModelPatcher:
|
||||
def get_clone_model_override(self):
|
||||
return self.model, (self.backup, self.backup_buffers, self.object_patches_backup, self.pinned)
|
||||
|
||||
def clone(self, disable_dynamic=False, model_override=None):
|
||||
def clone(self, disable_dynamic=False, model_override=None, force_deepcopy=False):
|
||||
class_ = self.__class__
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if self.is_dynamic() and disable_dynamic or force_deepcopy:
|
||||
if self.is_dynamic() and disable_dynamic:
|
||||
class_ = ModelPatcher
|
||||
if model_override is None:
|
||||
if self.cached_patcher_init is None:
|
||||
raise RuntimeError("Cannot create non-dynamic delegate: cached_patcher_init is not initialized.")
|
||||
|
||||
18
comfy/ops.py
18
comfy/ops.py
@ -180,7 +180,7 @@ def cast_modules_with_vbar(comfy_modules, dtype, device, bias_dtype, non_blockin
|
||||
if pin is not None:
|
||||
cast_maybe_lowvram_patch([pin], dest, offload_stream)
|
||||
return
|
||||
if signature is None:
|
||||
if signature is None or args.high_ram:
|
||||
comfy.pinned_memory.pin_memory(m, subset=subset, size=size)
|
||||
pin = comfy.pinned_memory.get_pin(m, subset=subset)
|
||||
cast_maybe_lowvram_patch(source, pin, offload_stream, xfer_dest2=dest)
|
||||
@ -299,21 +299,21 @@ def cast_bias_weight(s, input=None, dtype=None, device=None, bias_dtype=None, of
|
||||
|
||||
non_blocking = comfy.model_management.device_supports_non_blocking(device)
|
||||
|
||||
if hasattr(s, "_v"):
|
||||
if hasattr(s, "_v") and comfy.model_management.is_device_cpu(device):
|
||||
|
||||
#vbar doesn't support CPU weights, but some custom nodes have weird paths
|
||||
#that might switch the layer to the CPU and expect it to work. We have to take
|
||||
#a clone conservatively as we are mmapped and some SFT files are packed misaligned
|
||||
#If you are a custom node author reading this, please move your layer to the GPU
|
||||
#or declare your ModelPatcher as CPU in the first place.
|
||||
if comfy.model_management.is_device_cpu(device):
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
materialize_meta_param(s, ["weight", "bias"])
|
||||
weight = s.weight.to(dtype=dtype, copy=True)
|
||||
if isinstance(weight, QuantizedTensor):
|
||||
weight = weight.dequantize()
|
||||
bias = s.bias.to(dtype=bias_dtype, copy=True) if s.bias is not None else None
|
||||
return format_return((weight, bias, (None, None, None)), offloadable)
|
||||
|
||||
elif hasattr(s, "_v") and s.weight.device != device:
|
||||
prefetched = hasattr(s, "_prefetch")
|
||||
offload_stream = None
|
||||
offload_device = None
|
||||
|
||||
@ -1450,6 +1450,17 @@ class WAN21_SCAIL(WAN21_T2V):
|
||||
out = model_base.WAN21_SCAIL(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
|
||||
class WAN21_SCAIL2(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
"model_type": "scail2",
|
||||
}
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
out = model_base.WAN21_SCAIL2(self, image_to_video=False, device=device)
|
||||
return out
|
||||
|
||||
class WAN22_WanDancer(WAN21_T2V):
|
||||
unet_config = {
|
||||
"image_model": "wan2.1",
|
||||
@ -2045,6 +2056,23 @@ class RT_DETR_v4(supported_models_base.BASE):
|
||||
return None
|
||||
|
||||
|
||||
class DepthAnything3(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "DepthAnything3",
|
||||
}
|
||||
|
||||
# Mono path: no num_heads / num_head_channels needed.
|
||||
unet_extra_config = {}
|
||||
|
||||
supported_inference_dtypes = [torch.float16, torch.bfloat16, torch.float32]
|
||||
|
||||
def get_model(self, state_dict, prefix="", device=None):
|
||||
return model_base.DepthAnything3(self, device=device)
|
||||
|
||||
def clip_target(self, state_dict={}):
|
||||
return None
|
||||
|
||||
|
||||
class ErnieImage(supported_models_base.BASE):
|
||||
unet_config = {
|
||||
"image_model": "ernie",
|
||||
@ -2259,6 +2287,7 @@ models = [
|
||||
WAN22_Animate,
|
||||
WAN21_FlowRVS,
|
||||
WAN21_SCAIL,
|
||||
WAN21_SCAIL2,
|
||||
WAN22_WanDancer,
|
||||
Hunyuan3Dv2mini,
|
||||
Hunyuan3Dv2,
|
||||
@ -2286,4 +2315,5 @@ models = [
|
||||
CogVideoX_I2V,
|
||||
CogVideoX_T2V,
|
||||
SVD_img2vid,
|
||||
DepthAnything3,
|
||||
]
|
||||
|
||||
@ -27,10 +27,13 @@ class VideoInput(ABC):
|
||||
path: Union[str, IO[bytes]],
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""
|
||||
Abstract method to save the video input to a file.
|
||||
|
||||
bit_depth selects the encoded bit depth; None keeps the video's native depth.
|
||||
"""
|
||||
pass
|
||||
|
||||
@ -83,6 +86,14 @@ class VideoInput(ABC):
|
||||
components = self.get_components()
|
||||
return components.images.shape[2], components.images.shape[1]
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
"""
|
||||
Returns the bit depth of the video (e.g. 8 or 10).
|
||||
|
||||
Default implementation returns 8; subclasses report their real depth.
|
||||
"""
|
||||
return 8
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
|
||||
@ -52,6 +52,12 @@ def get_open_write_kwargs(
|
||||
return open_kwargs
|
||||
|
||||
|
||||
def video_stream_bit_depth(stream) -> int:
|
||||
if stream is None or stream.format is None or not stream.format.components:
|
||||
return 8
|
||||
return max(component.bits for component in stream.format.components)
|
||||
|
||||
|
||||
class VideoFromFile(VideoInput):
|
||||
"""
|
||||
Class representing video input from a file.
|
||||
@ -97,6 +103,13 @@ class VideoFromFile(VideoInput):
|
||||
return stream.width, stream.height
|
||||
raise ValueError(f"No video stream found in file '{self.__file}'")
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode="r") as container:
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
return video_stream_bit_depth(video_stream)
|
||||
|
||||
def get_duration(self) -> float:
|
||||
"""
|
||||
Returns the duration of the video in seconds.
|
||||
@ -257,6 +270,7 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
image_format = 'gbrpf32le'
|
||||
process_image_format = lambda a: a
|
||||
align_graph = None
|
||||
audio = None
|
||||
|
||||
streams = [video_stream]
|
||||
@ -310,7 +324,24 @@ class VideoFromFile(VideoInput):
|
||||
|
||||
checked_alpha = True
|
||||
|
||||
img = frame.to_ndarray(format=image_format) # shape: (H, W, 4)
|
||||
# Fix non-deterministic video decode when the video width is not a multiple of 32
|
||||
# For non-yuvj pixel formats (all H.264/H.265 video)
|
||||
if image_format in ('gbrpf32le', 'gbrapf32le') and frame.width % 32 != 0:
|
||||
if align_graph is None:
|
||||
pad_w = ((frame.width + 31) // 32) * 32
|
||||
g = av.filter.Graph()
|
||||
g_src = g.add_buffer(width=frame.width, height=frame.height,
|
||||
format=frame.format.name, time_base=video_stream.time_base)
|
||||
g_pad = g.add('pad', f'{pad_w}:{frame.height}:0:0')
|
||||
g_sink = g.add('buffersink')
|
||||
g_src.link_to(g_pad)
|
||||
g_pad.link_to(g_sink)
|
||||
g.configure()
|
||||
align_graph = (g, g_src, g_sink)
|
||||
align_graph[1].push(frame)
|
||||
img = np.ascontiguousarray(align_graph[2].pull().to_ndarray(format=image_format)[:, :frame.width])
|
||||
else:
|
||||
img = frame.to_ndarray(format=image_format)
|
||||
if frame.rotation != 0:
|
||||
k = int(round(frame.rotation // 90))
|
||||
img = np.rot90(img, k=k, axes=(0, 1)).copy()
|
||||
@ -377,25 +408,32 @@ class VideoFromFile(VideoInput):
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
if isinstance(self.__file, io.BytesIO):
|
||||
self.__file.seek(0) # Reset the BytesIO object to the beginning
|
||||
with av.open(self.__file, mode='r') as container:
|
||||
container_format = container.format.name
|
||||
video_encoding = container.streams.video[0].codec.name if len(container.streams.video) > 0 else None
|
||||
video_stream = container.streams.video[0] if len(container.streams.video) > 0 else None
|
||||
video_encoding = video_stream.codec.name if video_stream is not None else None
|
||||
source_bit_depth = video_stream_bit_depth(video_stream)
|
||||
reuse_streams = True
|
||||
if format != VideoContainer.AUTO and format not in container_format.split(","):
|
||||
reuse_streams = False
|
||||
if codec != VideoCodec.AUTO and codec != video_encoding and video_encoding is not None:
|
||||
reuse_streams = False
|
||||
if bit_depth is not None and video_encoding is not None and bit_depth != source_bit_depth:
|
||||
reuse_streams = False
|
||||
if self.__start_time or self.__duration:
|
||||
reuse_streams = False
|
||||
|
||||
if not reuse_streams:
|
||||
if bit_depth is None:
|
||||
bit_depth = source_bit_depth
|
||||
components = self.get_components_internal(container)
|
||||
video = VideoFromComponents(components)
|
||||
return video.save_to(
|
||||
path, format=format, codec=codec, metadata=metadata
|
||||
path, format=format, codec=codec, metadata=metadata, bit_depth=bit_depth,
|
||||
)
|
||||
|
||||
streams = container.streams
|
||||
@ -451,8 +489,10 @@ class VideoFromComponents(VideoInput):
|
||||
Class representing video input from tensors.
|
||||
"""
|
||||
|
||||
def __init__(self, components: VideoComponents):
|
||||
def __init__(self, components: VideoComponents, bit_depth: int = 8):
|
||||
self.__components = components
|
||||
# Tensor components have no inherent bit depth; this is the depth used when encoding.
|
||||
self.__bit_depth = bit_depth
|
||||
|
||||
def get_components(self) -> VideoComponents:
|
||||
return VideoComponents(
|
||||
@ -461,18 +501,26 @@ class VideoFromComponents(VideoInput):
|
||||
frame_rate=self.__components.frame_rate,
|
||||
)
|
||||
|
||||
def get_bit_depth(self) -> int:
|
||||
return self.__bit_depth
|
||||
|
||||
def save_to(
|
||||
self,
|
||||
path: str,
|
||||
format: VideoContainer = VideoContainer.AUTO,
|
||||
codec: VideoCodec = VideoCodec.AUTO,
|
||||
metadata: Optional[dict] = None,
|
||||
bit_depth: int | None = None,
|
||||
):
|
||||
"""Save the video to a file path or BytesIO buffer."""
|
||||
if format != VideoContainer.AUTO and format != VideoContainer.MP4:
|
||||
raise ValueError("Only MP4 format is supported for now")
|
||||
if codec != VideoCodec.AUTO and codec != VideoCodec.H264:
|
||||
raise ValueError("Only H264 codec is supported for now")
|
||||
# None means "use the depth this video was created with" (CreateVideo's choice).
|
||||
if bit_depth is None:
|
||||
bit_depth = self.__bit_depth
|
||||
is_10bit = bit_depth >= 10
|
||||
extra_kwargs = {}
|
||||
if isinstance(format, VideoContainer) and format != VideoContainer.AUTO:
|
||||
extra_kwargs["format"] = format.value
|
||||
@ -488,10 +536,11 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
frame_rate = Fraction(round(self.__components.frame_rate * 1000), 1000)
|
||||
# Create a video stream
|
||||
pix_fmt = "yuv420p10le" if is_10bit else "yuv420p"
|
||||
video_stream = output.add_stream('h264', rate=frame_rate)
|
||||
video_stream.width = self.__components.images.shape[2]
|
||||
video_stream.height = self.__components.images.shape[1]
|
||||
video_stream.pix_fmt = 'yuv420p'
|
||||
video_stream.pix_fmt = pix_fmt
|
||||
|
||||
# Create an audio stream
|
||||
audio_sample_rate = 1
|
||||
@ -505,9 +554,14 @@ class VideoFromComponents(VideoInput):
|
||||
|
||||
# Encode video
|
||||
for i, frame in enumerate(self.__components.images):
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format='yuv420p') # Convert to YUV420P as required by h264
|
||||
if is_10bit:
|
||||
# 16-bit RGB keeps float precision through the conversion to 10-bit YUV.
|
||||
img = (frame.float() * 65535).clamp(0, 65535).cpu().numpy().astype(np.uint16) # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format="rgb48le")
|
||||
else:
|
||||
img = (frame * 255).clamp(0, 255).byte().cpu().numpy() # shape: (H, W, 3)
|
||||
frame = av.VideoFrame.from_ndarray(img, format='rgb24')
|
||||
frame = frame.reformat(format=pix_fmt)
|
||||
packet = video_stream.encode(frame)
|
||||
output.mux(packet)
|
||||
|
||||
|
||||
@ -1400,7 +1400,8 @@ class V3Data(TypedDict):
|
||||
class HiddenHolder:
|
||||
def __init__(self, unique_id: str, prompt: Any,
|
||||
extra_pnginfo: Any, dynprompt: Any,
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str, **kwargs):
|
||||
auth_token_comfy_org: str, api_key_comfy_org: str,
|
||||
comfy_usage_source: str = None, **kwargs):
|
||||
self.unique_id = unique_id
|
||||
"""UNIQUE_ID is the unique identifier of the node, and matches the id property of the node on the client side. It is commonly used in client-server communications (see messages)."""
|
||||
self.prompt = prompt
|
||||
@ -1413,6 +1414,8 @@ class HiddenHolder:
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
self.api_key_comfy_org = api_key_comfy_org
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
self.comfy_usage_source = comfy_usage_source
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
def __getattr__(self, key: str):
|
||||
'''If hidden variable not found, return None.'''
|
||||
@ -1429,6 +1432,7 @@ class HiddenHolder:
|
||||
dynprompt=d.get(Hidden.dynprompt, None),
|
||||
auth_token_comfy_org=d.get(Hidden.auth_token_comfy_org, None),
|
||||
api_key_comfy_org=d.get(Hidden.api_key_comfy_org, None),
|
||||
comfy_usage_source=d.get(Hidden.comfy_usage_source, None),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@ -1451,6 +1455,8 @@ class Hidden(str, Enum):
|
||||
"""AUTH_TOKEN_COMFY_ORG is a token acquired from signing into a ComfyOrg account on frontend."""
|
||||
api_key_comfy_org = "API_KEY_COMFY_ORG"
|
||||
"""API_KEY_COMFY_ORG is an API Key generated by ComfyOrg that allows skipping signing into a ComfyOrg account on frontend."""
|
||||
comfy_usage_source = "COMFY_USAGE_SOURCE"
|
||||
"""COMFY_USAGE_SOURCE identifies the client that submitted the prompt (e.g. comfyui-frontend, comfy-cli, comfyui-mcp); forwarded to API nodes' upstream requests via the Comfy-Usage-Source header."""
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -1654,6 +1660,8 @@ class Schema:
|
||||
self.hidden.append(Hidden.auth_token_comfy_org)
|
||||
if Hidden.api_key_comfy_org not in self.hidden:
|
||||
self.hidden.append(Hidden.api_key_comfy_org)
|
||||
if Hidden.comfy_usage_source not in self.hidden:
|
||||
self.hidden.append(Hidden.comfy_usage_source)
|
||||
# if is an output_node, will need prompt and extra_pnginfo
|
||||
if self.is_output_node:
|
||||
if Hidden.prompt not in self.hidden:
|
||||
|
||||
9
comfy_api_nodes/apis/__init__.py
generated
9
comfy_api_nodes/apis/__init__.py
generated
@ -1310,13 +1310,6 @@ class KlingTaskStatus(str, Enum):
|
||||
failed = 'failed'
|
||||
|
||||
|
||||
class KlingTextToVideoModelName(str, Enum):
|
||||
kling_v1 = 'kling-v1'
|
||||
kling_v1_6 = 'kling-v1-6'
|
||||
kling_v2_1_master = 'kling-v2-1-master'
|
||||
kling_v2_5_turbo = 'kling-v2-5-turbo'
|
||||
|
||||
|
||||
class KlingVideoGenAspectRatio(str, Enum):
|
||||
field_16_9 = '16:9'
|
||||
field_9_16 = '9:16'
|
||||
@ -5179,7 +5172,7 @@ class KlingText2VideoRequest(BaseModel):
|
||||
duration: Optional[KlingVideoGenDuration] = '5'
|
||||
external_task_id: Optional[str] = Field(None, description='Customized Task ID')
|
||||
mode: Optional[KlingVideoGenMode] = 'std'
|
||||
model_name: Optional[KlingTextToVideoModelName] = 'kling-v1'
|
||||
model_name: Optional[str] = 'kling-v1'
|
||||
negative_prompt: Optional[str] = Field(
|
||||
None, description='Negative text prompt', max_length=2500
|
||||
)
|
||||
|
||||
@ -67,15 +67,6 @@ class RunwayImageToVideoResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayTaskStatusEnum(str, Enum):
|
||||
SUCCEEDED = 'SUCCEEDED'
|
||||
RUNNING = 'RUNNING'
|
||||
FAILED = 'FAILED'
|
||||
PENDING = 'PENDING'
|
||||
CANCELLED = 'CANCELLED'
|
||||
THROTTLED = 'THROTTLED'
|
||||
|
||||
|
||||
class RunwayTaskStatusResponse(BaseModel):
|
||||
createdAt: datetime = Field(..., description='Task creation timestamp')
|
||||
id: str = Field(..., description='Task ID')
|
||||
@ -86,7 +77,7 @@ class RunwayTaskStatusResponse(BaseModel):
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
status: RunwayTaskStatusEnum
|
||||
status: str = Field(..., description="SUCCEEDED, RUNNING, FAILED, PENDING, CANCELLED or THROTTLED")
|
||||
|
||||
|
||||
class Model4(str, Enum):
|
||||
@ -125,3 +116,144 @@ class RunwayTextToImageRequest(BaseModel):
|
||||
|
||||
class RunwayTextToImageResponse(BaseModel):
|
||||
id: Optional[str] = Field(None, description='Task ID')
|
||||
|
||||
|
||||
class RunwayAleph2IO:
|
||||
"""Custom socket types for chaining Aleph2 guidance images."""
|
||||
|
||||
KEYFRAME = "RUNWAY_ALEPH2_KEYFRAME"
|
||||
PROMPT_IMAGE = "RUNWAY_ALEPH2_PROMPT_IMAGE"
|
||||
|
||||
|
||||
# Keyframe timing modes (anchored to the INPUT video). Stored on the chain item and used to
|
||||
# choose the request model below. The values match the Aleph2 keyframe union field names.
|
||||
KEYFRAME_MODE_SECONDS = "seconds" # absolute time, in seconds, from the start of the input video
|
||||
KEYFRAME_MODE_AT = "at" # fraction [0.0, 1.0] of the input video duration
|
||||
|
||||
# Prompt-image position modes (anchored to the OUTPUT video). Values match the Aleph2 position `type`.
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP = "timestamp" # absolute time, in seconds, from the start of the output video
|
||||
PROMPT_IMAGE_MODE_POSITION = "position" # fraction [0.0, 1.0] of the output video duration
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeItem:
|
||||
"""A guidance image anchored to a point of the INPUT video (one Aleph2 ``keyframe``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # KEYFRAME_MODE_SECONDS | KEYFRAME_MODE_AT
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeChain:
|
||||
"""An ordered collection of keyframes, built by chaining Runway Aleph2 Keyframe nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2KeyframeItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2KeyframeItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2KeyframeChain":
|
||||
c = RunwayAleph2KeyframeChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageItem:
|
||||
"""A guidance image anchored to a point of the OUTPUT video (one Aleph2 ``promptImage``)."""
|
||||
|
||||
def __init__(self, image, mode: str, value: float):
|
||||
self.image = image
|
||||
self.mode = mode # PROMPT_IMAGE_MODE_TIMESTAMP | PROMPT_IMAGE_MODE_POSITION
|
||||
self.value = value
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageChain:
|
||||
"""An ordered collection of prompt images, built by chaining Runway Aleph2 Prompt Image nodes."""
|
||||
|
||||
def __init__(self):
|
||||
self.items: list[RunwayAleph2PromptImageItem] = []
|
||||
|
||||
def add(self, item: RunwayAleph2PromptImageItem) -> None:
|
||||
self.items.append(item)
|
||||
|
||||
def clone(self) -> "RunwayAleph2PromptImageChain":
|
||||
c = RunwayAleph2PromptImageChain()
|
||||
c.items = list(self.items)
|
||||
return c
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeSeconds(BaseModel):
|
||||
seconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the input video when this guidance image should apply.",
|
||||
ge=0.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeAt(BaseModel):
|
||||
at: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the input video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2TimestampPosition(BaseModel):
|
||||
type: str = Field(default="timestamp")
|
||||
timestampSeconds: float = Field(
|
||||
...,
|
||||
description="Absolute timestamp in seconds from the start of the output video.",
|
||||
ge=0.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2RelativePosition(BaseModel):
|
||||
type: str = Field(default="position")
|
||||
positionPercentage: float = Field(
|
||||
...,
|
||||
description="Position as a fraction [0.0, 1.0] of the total output video duration.",
|
||||
ge=0.0,
|
||||
le=1.0,
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImage(BaseModel):
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
uri: str = Field(...)
|
||||
|
||||
|
||||
class RunwayAleph2ContentModeration(BaseModel):
|
||||
publicFigureThreshold: str = Field(
|
||||
...,
|
||||
description='When set to "low", the content moderation system is less strict about '
|
||||
'recognizable public figures. One of "auto" or "low".',
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Request(BaseModel):
|
||||
model: str = Field(default="aleph2")
|
||||
promptText: str = Field(
|
||||
...,
|
||||
description="A non-empty string describing what should appear in the output.",
|
||||
min_length=1,
|
||||
max_length=1000,
|
||||
)
|
||||
videoUri: str = Field(...)
|
||||
seed: int = Field(..., description="Random seed for generation", ge=0, le=4294967295)
|
||||
contentModeration: RunwayAleph2ContentModeration = Field(...)
|
||||
keyframes: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] | None = Field(
|
||||
None,
|
||||
description="Timed guidance images placed at specific points in the input video. Up to 5.",
|
||||
)
|
||||
promptImage: list[RunwayAleph2PromptImage] | None = Field(
|
||||
None,
|
||||
description="Up to 5 image keyframes for guiding the edit at specific points in the output video.",
|
||||
)
|
||||
|
||||
|
||||
class RunwayAleph2Response(BaseModel):
|
||||
id: str | None = Field(None, description="Task ID")
|
||||
|
||||
@ -208,6 +208,10 @@ class TripoMultiviewToModelRequest(BaseModel):
|
||||
quad: bool | None = Field(False, description="Whether to apply quad to the generated model")
|
||||
|
||||
|
||||
class TripoTexturePrompt(BaseModel):
|
||||
text: str | None = Field(None, description="Text guidance for texture generation")
|
||||
|
||||
|
||||
class TripoTextureModelRequest(BaseModel):
|
||||
type: TripoTaskType = Field(TripoTaskType.TEXTURE_MODEL, description="Type of task")
|
||||
original_model_task_id: str = Field(..., description="The task ID of the original model")
|
||||
@ -219,6 +223,11 @@ class TripoTextureModelRequest(BaseModel):
|
||||
texture_alignment: TripoTextureAlignment | None = Field(
|
||||
TripoTextureAlignment.ORIGINAL_IMAGE, description="The texture alignment method"
|
||||
)
|
||||
texture_prompt: TripoTexturePrompt | None = Field(
|
||||
None,
|
||||
description="Optional guidance for texturing. Required in practice for imported models, "
|
||||
"which carry no source image to infer texture from.",
|
||||
)
|
||||
|
||||
|
||||
class TripoRefineModelRequest(BaseModel):
|
||||
@ -307,6 +316,17 @@ class TripoP1MultiviewToModelRequest(TripoP1CommonRequest):
|
||||
orientation: str | None = None
|
||||
|
||||
|
||||
class TripoImportModelRequest(BaseModel):
|
||||
"""Request for the comfy-api composite import endpoint (/proxy/tripo/v2/openapi/import).
|
||||
|
||||
The model file is uploaded to ComfyUI API storage first; the backend downloads it from
|
||||
`url`, re-uploads it to Tripo's storage and creates the import_model task server-side.
|
||||
"""
|
||||
|
||||
url: str = Field(..., description="ComfyUI API storage download URL of the model file")
|
||||
format: str = Field(..., description='File format: "glb", "fbx", "obj" or "stl"')
|
||||
|
||||
|
||||
class TripoTaskOutput(BaseModel):
|
||||
model: str | None = Field(None, description="URL to the model")
|
||||
base_model: str | None = Field(None, description="URL to the base model")
|
||||
|
||||
@ -289,7 +289,7 @@ class BriaRemoveVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -357,7 +357,7 @@ class BriaVideoGreenScreen(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -433,7 +433,7 @@ class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -452,7 +452,10 @@ class BriaVideoReplaceBackground(IO.ComfyNode):
|
||||
validate_video_duration(background_video, max_duration=60.0)
|
||||
background_url = await upload_video_to_comfyapi(cls, background_video, wait_label="Uploading background")
|
||||
else:
|
||||
background_url = await upload_image_to_comfyapi(cls, background_image, wait_label="Uploading background")
|
||||
# Bria's replace_background 500s on RGBA, so drop the alpha channel before upload.
|
||||
background_url = await upload_image_to_comfyapi(
|
||||
cls, background_image[:, :, :, :3], wait_label="Uploading background"
|
||||
)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
ApiEndpoint(path="/proxy/bria/v2/video/edit/replace_background", method="POST"),
|
||||
@ -530,7 +533,7 @@ class BriaTransparentVideoBackground(IO.ComfyNode):
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd":0.14,"format":{"suffix":"/second"}}""",
|
||||
expr="""{"type":"usd","usd":0.0042,"format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@ -571,7 +574,7 @@ class BriaExtension(ComfyExtension):
|
||||
BriaRemoveImageBackground,
|
||||
BriaRemoveVideoBackground,
|
||||
BriaVideoGreenScreen,
|
||||
# BriaVideoReplaceBackground, # server returns Status 500 when we pass background video
|
||||
BriaVideoReplaceBackground,
|
||||
BriaTransparentVideoBackground,
|
||||
]
|
||||
|
||||
|
||||
@ -436,7 +436,7 @@ async def execute_text2video(
|
||||
negative_prompt=negative_prompt if negative_prompt else None,
|
||||
duration=KlingVideoGenDuration(duration),
|
||||
mode=KlingVideoGenMode(model_mode),
|
||||
model_name=KlingVideoGenModelName(model_name),
|
||||
model_name=model_name,
|
||||
cfg_scale=cfg_scale,
|
||||
aspect_ratio=KlingVideoGenAspectRatio(aspect_ratio),
|
||||
camera_control=camera_control,
|
||||
|
||||
@ -9,6 +9,7 @@ from PIL import Image
|
||||
from typing_extensions import override
|
||||
|
||||
import folder_paths
|
||||
from comfy.utils import common_upscale
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api_nodes.apis.openai import (
|
||||
InputFileContent,
|
||||
@ -62,7 +63,8 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
timeout: Request timeout in seconds. Defaults to None (no timeout).
|
||||
|
||||
Returns:
|
||||
A torch.Tensor representing the image (1, H, W, C).
|
||||
A torch.Tensor of shape (N, H, W, C) with all returned images; images whose
|
||||
dimensions differ from the first image's are resized to match it.
|
||||
|
||||
Raises:
|
||||
ValueError: If the response is not valid.
|
||||
@ -89,6 +91,14 @@ async def validate_and_cast_response(response, timeout: int = None) -> torch.Ten
|
||||
arr = np.asarray(pil_img).astype(np.float32) / 255.0
|
||||
image_tensors.append(torch.from_numpy(arr))
|
||||
|
||||
# With size="auto" the API can return images whose dimensions differ by a few pixels within a single response
|
||||
# resize them to the first image's dimensions so they can be stacked into one batch.
|
||||
ref_h, ref_w = image_tensors[0].shape[:2]
|
||||
for i, t in enumerate(image_tensors):
|
||||
if t.shape[:2] != (ref_h, ref_w):
|
||||
samples = t.unsqueeze(0).movedim(-1, 1)
|
||||
samples = common_upscale(samples, ref_w, ref_h, "bilinear", "center")
|
||||
image_tensors[i] = samples.movedim(1, -1).squeeze(0)
|
||||
return torch.stack(image_tensors, dim=0)
|
||||
|
||||
|
||||
|
||||
@ -30,13 +30,33 @@ from comfy_api_nodes.apis.runway import (
|
||||
Model4,
|
||||
ReferenceImage,
|
||||
RunwayTextToImageAspectRatioEnum,
|
||||
RunwayAleph2IO,
|
||||
RunwayAleph2KeyframeChain,
|
||||
RunwayAleph2KeyframeItem,
|
||||
RunwayAleph2PromptImageChain,
|
||||
RunwayAleph2PromptImageItem,
|
||||
RunwayAleph2Request,
|
||||
RunwayAleph2Response,
|
||||
RunwayAleph2KeyframeSeconds,
|
||||
RunwayAleph2KeyframeAt,
|
||||
RunwayAleph2PromptImage,
|
||||
RunwayAleph2TimestampPosition,
|
||||
RunwayAleph2RelativePosition,
|
||||
RunwayAleph2ContentModeration,
|
||||
KEYFRAME_MODE_SECONDS,
|
||||
KEYFRAME_MODE_AT,
|
||||
PROMPT_IMAGE_MODE_TIMESTAMP,
|
||||
PROMPT_IMAGE_MODE_POSITION,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
image_tensor_pair_to_batch,
|
||||
validate_string,
|
||||
validate_image_dimensions,
|
||||
validate_image_aspect_ratio,
|
||||
validate_video_duration,
|
||||
upload_images_to_comfyapi,
|
||||
upload_image_to_comfyapi,
|
||||
upload_video_to_comfyapi,
|
||||
download_url_to_video_output,
|
||||
download_url_to_image_tensor,
|
||||
ApiEndpoint,
|
||||
@ -45,6 +65,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
|
||||
PATH_IMAGE_TO_VIDEO = "/proxy/runway/image_to_video"
|
||||
PATH_VIDEO_TO_VIDEO = "/proxy/runway/video_to_video"
|
||||
PATH_TEXT_TO_IMAGE = "/proxy/runway/text_to_image"
|
||||
PATH_GET_TASK_STATUS = "/proxy/runway/tasks"
|
||||
|
||||
@ -53,12 +74,6 @@ AVERAGE_DURATION_FLF_SECONDS = 256
|
||||
AVERAGE_DURATION_T2I_SECONDS = 41
|
||||
|
||||
|
||||
class RunwayApiError(Exception):
|
||||
"""Base exception for Runway API errors."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class RunwayGen4TurboAspectRatio(str, Enum):
|
||||
"""Aspect ratios supported for Image to Video API when using gen4_turbo model."""
|
||||
|
||||
@ -84,14 +99,6 @@ def get_video_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
return None
|
||||
|
||||
|
||||
def extract_progress_from_task_status(
|
||||
response: TaskStatusResponse,
|
||||
) -> float | None:
|
||||
if hasattr(response, "progress") and response.progress is not None:
|
||||
return response.progress * 100
|
||||
return None
|
||||
|
||||
|
||||
def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
"""Returns the image URL from the task status response if it exists."""
|
||||
if hasattr(response, "output") and len(response.output) > 0:
|
||||
@ -102,14 +109,13 @@ def get_image_url_from_task_status(response: TaskStatusResponse) -> str | None:
|
||||
async def get_response(
|
||||
cls: type[IO.ComfyNode], task_id: str, estimated_duration: int | None = None
|
||||
) -> TaskStatusResponse:
|
||||
"""Poll the task status until it is finished then get the response."""
|
||||
return await poll_op(
|
||||
cls,
|
||||
ApiEndpoint(path=f"{PATH_GET_TASK_STATUS}/{task_id}"),
|
||||
response_model=TaskStatusResponse,
|
||||
status_extractor=lambda r: r.status.value,
|
||||
status_extractor=lambda r: r.status,
|
||||
estimated_duration=estimated_duration,
|
||||
progress_extractor=extract_progress_from_task_status,
|
||||
progress_extractor=lambda r: r.progress * 100 if r.progress is not None else None,
|
||||
)
|
||||
|
||||
|
||||
@ -127,7 +133,7 @@ async def generate_video(
|
||||
|
||||
final_response = await get_response(cls, initial_response.id, estimated_duration)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no video data found in response.")
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
video_url = get_video_url_from_task_status(final_response)
|
||||
return await download_url_to_video_output(video_url)
|
||||
@ -410,7 +416,7 @@ class RunwayFirstLastFrameNode(IO.ComfyNode):
|
||||
mime_type="image/png",
|
||||
)
|
||||
if len(download_urls) != 2:
|
||||
raise RunwayApiError("Failed to upload one or more images to comfy api.")
|
||||
raise ValueError("Failed to upload one or more images to comfy api.")
|
||||
|
||||
return IO.NodeOutput(
|
||||
await generate_video(
|
||||
@ -514,11 +520,321 @@ class RunwayTextToImageNode(IO.ComfyNode):
|
||||
estimated_duration=AVERAGE_DURATION_T2I_SECONDS,
|
||||
)
|
||||
if not final_response.output:
|
||||
raise RunwayApiError("Runway task succeeded but no image data found in response.")
|
||||
raise ValueError("Runway task succeeded but no image data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_image_tensor(get_image_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
_TIMING_ABSOLUTE = "Absolute time (seconds)"
|
||||
_TIMING_FRACTION = "Fraction of duration (0.0-1.0)"
|
||||
|
||||
|
||||
class RunwayAleph2KeyframeNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2KeyframeNode",
|
||||
display_name="Runway Aleph2 Keyframe",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the input (source) video, so Aleph2 "
|
||||
"steers the edit at that point of your footage. Connect this to the 'keyframes' input of "
|
||||
"the Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'keyframes' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to apply at the chosen moment of the input video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"timing",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the input video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the input video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the input video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Optional earlier keyframes to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.KEYFRAME).Output(display_name="keyframes")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
timing: dict,
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = keyframes.clone() if keyframes is not None else RunwayAleph2KeyframeChain()
|
||||
if timing["timing"] == _TIMING_ABSOLUTE:
|
||||
mode, value = KEYFRAME_MODE_SECONDS, float(timing["seconds"])
|
||||
else:
|
||||
mode, value = KEYFRAME_MODE_AT, float(timing["fraction"])
|
||||
chain.add(RunwayAleph2KeyframeItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2PromptImageNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2PromptImageNode",
|
||||
display_name="Runway Aleph2 Prompt Image",
|
||||
category="partner/video/Runway",
|
||||
description="Anchor a guidance image to a moment of the output (result) video, to guide what "
|
||||
"the edited video looks like at that point. Connect this to the 'prompt_images' input of the "
|
||||
"Runway Aleph2 Video to Video node; chain several together (up to 5) via the optional "
|
||||
"'prompt_images' input below.",
|
||||
inputs=[
|
||||
IO.Image.Input(
|
||||
"image",
|
||||
tooltip="The guidance image to place at the chosen moment of the output video.",
|
||||
),
|
||||
IO.DynamicCombo.Input(
|
||||
"position",
|
||||
options=[
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_ABSOLUTE,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"seconds",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=30.0,
|
||||
step=0.1,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Time in seconds from start of the output video where this image applies.",
|
||||
),
|
||||
],
|
||||
),
|
||||
IO.DynamicCombo.Option(
|
||||
_TIMING_FRACTION,
|
||||
[
|
||||
IO.Float.Input(
|
||||
"fraction",
|
||||
default=0.0,
|
||||
min=0.0,
|
||||
max=1.0,
|
||||
step=0.01,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Where in the output video this image applies, "
|
||||
"as a fraction of its duration (0.0 = start, 1.0 = end).",
|
||||
),
|
||||
],
|
||||
),
|
||||
],
|
||||
tooltip="How to place this image on the output video's timeline.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Optional earlier prompt images to chain with this one.",
|
||||
),
|
||||
],
|
||||
outputs=[IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Output(display_name="prompt_images")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(
|
||||
cls,
|
||||
image: Input.Image,
|
||||
position: dict,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
chain = prompt_images.clone() if prompt_images is not None else RunwayAleph2PromptImageChain()
|
||||
if position["position"] == _TIMING_ABSOLUTE:
|
||||
mode, value = PROMPT_IMAGE_MODE_TIMESTAMP, float(position["seconds"])
|
||||
else:
|
||||
mode, value = PROMPT_IMAGE_MODE_POSITION, float(position["fraction"])
|
||||
chain.add(RunwayAleph2PromptImageItem(image=image, mode=mode, value=value))
|
||||
return IO.NodeOutput(chain)
|
||||
|
||||
|
||||
class RunwayAleph2VideoToVideoNode(IO.ComfyNode):
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="RunwayAleph2VideoToVideoNode",
|
||||
display_name="Runway Aleph2 Video to Video",
|
||||
category="partner/video/Runway",
|
||||
description="Edit a video with a text prompt using Runway's Aleph2 model. Aleph2 transforms "
|
||||
"your footage (restyle, relight, add or remove elements, change the viewpoint) while keeping "
|
||||
"the original motion and timing; the output resolution matches the input video, which must be "
|
||||
"2-30 seconds at 30 fps or lower. Optionally steer the edit with either keyframes (anchored to "
|
||||
"the input video) or prompt images (anchored to the output video) - use one or the other, not both.",
|
||||
inputs=[
|
||||
IO.String.Input(
|
||||
"prompt",
|
||||
multiline=True,
|
||||
default="",
|
||||
tooltip="Describes what should appear in the output (1-1000 characters).",
|
||||
),
|
||||
IO.Video.Input(
|
||||
"video",
|
||||
tooltip="Input video to edit. Must be 2-30 seconds at 30 fps or lower.",
|
||||
),
|
||||
IO.Int.Input(
|
||||
"seed",
|
||||
default=0,
|
||||
min=0,
|
||||
max=4294967295,
|
||||
step=1,
|
||||
control_after_generate=True,
|
||||
display_mode=IO.NumberDisplay.number,
|
||||
tooltip="Random seed for generation",
|
||||
),
|
||||
IO.Combo.Input(
|
||||
"public_figure_threshold",
|
||||
options=["auto", "low"],
|
||||
default="low",
|
||||
tooltip="Content moderation for recognizable public figures.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.KEYFRAME).Input(
|
||||
"keyframes",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the input video, from Aleph2 Keyframe nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
IO.Custom(RunwayAleph2IO.PROMPT_IMAGE).Input(
|
||||
"prompt_images",
|
||||
optional=True,
|
||||
tooltip="Guidance images anchored to the output video, from Aleph2 Prompt Image nodes (up to 5). "
|
||||
"Use keyframes or prompt images, not both.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Video.Output(),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"usd","usd": 0.4004, "format":{"suffix":"/second"}}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(
|
||||
cls,
|
||||
prompt: str,
|
||||
video: Input.Video,
|
||||
seed: int,
|
||||
public_figure_threshold: str = "low",
|
||||
keyframes: RunwayAleph2KeyframeChain | None = None,
|
||||
prompt_images: RunwayAleph2PromptImageChain | None = None,
|
||||
) -> IO.NodeOutput:
|
||||
validate_string(prompt, min_length=1, max_length=1000)
|
||||
validate_video_duration(
|
||||
video,
|
||||
min_duration=2.0,
|
||||
max_duration=30.0,
|
||||
)
|
||||
try:
|
||||
fps = float(video.get_frame_rate())
|
||||
except Exception:
|
||||
fps = None
|
||||
if fps is not None and fps > 30.0 + 0.01:
|
||||
raise ValueError(f"Input video frame rate ({fps:.2f} fps) exceeds Aleph2's maximum of 30 fps.")
|
||||
|
||||
if (keyframes and keyframes.items) and (prompt_images and prompt_images.items):
|
||||
raise ValueError("Aleph2 accepts either keyframes or prompt images, not both.")
|
||||
|
||||
video_duration: float | None = None
|
||||
try:
|
||||
video_duration = video.get_duration()
|
||||
except Exception:
|
||||
video_duration = None
|
||||
|
||||
def _check_seconds(value: float, label: str) -> None:
|
||||
if video_duration is not None and value > video_duration + 0.0001:
|
||||
raise ValueError(f"{label} {value:.2f}s exceeds the input video duration ({video_duration:.2f}s).")
|
||||
|
||||
video_url = await upload_video_to_comfyapi(cls, video)
|
||||
|
||||
keyframe_models: list[RunwayAleph2KeyframeSeconds | RunwayAleph2KeyframeAt] = []
|
||||
if keyframes is not None:
|
||||
if len(keyframes.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 keyframes.")
|
||||
for item in keyframes.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
if item.mode == KEYFRAME_MODE_SECONDS:
|
||||
_check_seconds(item.value, "Keyframe timestamp")
|
||||
keyframe_models.append(RunwayAleph2KeyframeSeconds(seconds=item.value, uri=image_url))
|
||||
else:
|
||||
keyframe_models.append(RunwayAleph2KeyframeAt(at=item.value, uri=image_url))
|
||||
|
||||
prompt_image_models: list[RunwayAleph2PromptImage] = []
|
||||
if prompt_images is not None:
|
||||
if len(prompt_images.items) > 5:
|
||||
raise ValueError("Aleph2 supports at most 5 prompt images.")
|
||||
for item in prompt_images.items:
|
||||
image_url = await upload_image_to_comfyapi(cls, item.image, mime_type="image/png")
|
||||
position: RunwayAleph2TimestampPosition | RunwayAleph2RelativePosition
|
||||
if item.mode == PROMPT_IMAGE_MODE_TIMESTAMP:
|
||||
_check_seconds(item.value, "Prompt image timestamp")
|
||||
position = RunwayAleph2TimestampPosition(timestampSeconds=item.value)
|
||||
else:
|
||||
position = RunwayAleph2RelativePosition(positionPercentage=item.value)
|
||||
prompt_image_models.append(RunwayAleph2PromptImage(position=position, uri=image_url))
|
||||
|
||||
initial_response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path=PATH_VIDEO_TO_VIDEO, method="POST"),
|
||||
response_model=RunwayAleph2Response,
|
||||
data=RunwayAleph2Request(
|
||||
promptText=prompt,
|
||||
videoUri=video_url,
|
||||
seed=seed,
|
||||
contentModeration=RunwayAleph2ContentModeration(publicFigureThreshold=public_figure_threshold),
|
||||
keyframes=keyframe_models or None,
|
||||
promptImage=prompt_image_models or None,
|
||||
),
|
||||
)
|
||||
|
||||
final_response = await get_response(cls, initial_response.id)
|
||||
if not final_response.output:
|
||||
raise ValueError("Runway task succeeded but no video data found in response.")
|
||||
|
||||
return IO.NodeOutput(await download_url_to_video_output(get_video_url_from_task_status(final_response)))
|
||||
|
||||
|
||||
class RunwayExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[IO.ComfyNode]]:
|
||||
@ -527,6 +843,9 @@ class RunwayExtension(ComfyExtension):
|
||||
RunwayImageToVideoNodeGen3a,
|
||||
RunwayImageToVideoNodeGen4,
|
||||
RunwayTextToImageNode,
|
||||
RunwayAleph2VideoToVideoNode,
|
||||
RunwayAleph2KeyframeNode,
|
||||
RunwayAleph2PromptImageNode,
|
||||
]
|
||||
|
||||
|
||||
|
||||
@ -16,7 +16,7 @@ from comfy_api_nodes.util import (
|
||||
)
|
||||
from comfy_api_nodes.util._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
)
|
||||
@ -174,8 +174,7 @@ async def _stream_sonilo_music(
|
||||
"""POST ``form`` to Sonilo, read the NDJSON stream, and return the first stream's audio bytes."""
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", endpoint.path.lstrip("/"))
|
||||
|
||||
headers: dict[str, str] = {}
|
||||
headers.update(get_auth_header(cls))
|
||||
headers = get_comfy_api_headers(cls)
|
||||
headers.update(endpoint.headers)
|
||||
|
||||
node_id = get_node_id(cls)
|
||||
|
||||
@ -1,6 +1,6 @@
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy_api.latest import IO, ComfyExtension, Input
|
||||
from comfy_api.latest import IO, ComfyExtension, Input, Types
|
||||
from comfy_api_nodes.apis.tripo import (
|
||||
TripoAnimateRetargetRequest,
|
||||
TripoAnimateRigRequest,
|
||||
@ -8,6 +8,7 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoFileEmptyReference,
|
||||
TripoFileReference,
|
||||
TripoImageToModelRequest,
|
||||
TripoImportModelRequest,
|
||||
TripoModelVersion,
|
||||
TripoMultiviewToModelRequest,
|
||||
TripoOrientation,
|
||||
@ -21,6 +22,7 @@ from comfy_api_nodes.apis.tripo import (
|
||||
TripoTaskType,
|
||||
TripoTextToModelRequest,
|
||||
TripoTextureModelRequest,
|
||||
TripoTexturePrompt,
|
||||
TripoUrlReference,
|
||||
)
|
||||
from comfy_api_nodes.util import (
|
||||
@ -28,6 +30,7 @@ from comfy_api_nodes.util import (
|
||||
download_url_to_file_3d,
|
||||
poll_op,
|
||||
sync_op,
|
||||
upload_3d_model_to_comfyapi,
|
||||
upload_images_to_comfyapi,
|
||||
)
|
||||
|
||||
@ -538,6 +541,14 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
optional=True,
|
||||
advanced=True,
|
||||
),
|
||||
IO.String.Input(
|
||||
"texture_prompt",
|
||||
default="",
|
||||
multiline=True,
|
||||
optional=True,
|
||||
tooltip="Optional text guidance for texturing. Required in practice for imported "
|
||||
"models (Tripo: Import Model), which carry no source image to infer colors from.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.String.Output(display_name="model_file"), # for backward compatibility only
|
||||
@ -571,6 +582,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed: int | None = None,
|
||||
texture_quality: str | None = None,
|
||||
texture_alignment: str | None = None,
|
||||
texture_prompt: str = "",
|
||||
) -> IO.NodeOutput:
|
||||
response = await sync_op(
|
||||
cls,
|
||||
@ -583,6 +595,7 @@ class TripoTextureNode(IO.ComfyNode):
|
||||
texture_seed=texture_seed,
|
||||
texture_quality=texture_quality,
|
||||
texture_alignment=texture_alignment,
|
||||
texture_prompt=TripoTexturePrompt(text=texture_prompt.strip()) if texture_prompt.strip() else None,
|
||||
),
|
||||
)
|
||||
return await poll_until_finished(cls, response, average_duration=80)
|
||||
@ -915,6 +928,90 @@ class TripoConversionNode(IO.ComfyNode):
|
||||
return await poll_until_finished(cls, response, average_duration=30)
|
||||
|
||||
|
||||
class TripoImportModelNode(IO.ComfyNode):
|
||||
"""Imports an external 3D model into Tripo, producing a MODEL_TASK_ID for post-processing nodes."""
|
||||
|
||||
SUPPORTED_FORMATS = ("glb", "fbx", "obj", "stl")
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return IO.Schema(
|
||||
node_id="TripoImportModelNode",
|
||||
display_name="Tripo: Import Model",
|
||||
category="partner/3d/Tripo",
|
||||
description="Import an external 3D model (e.g. from Rodin, Hunyuan3D or a local file) into Tripo "
|
||||
"to use it with Tripo's post-processing nodes: Texture, Rig, Convert. "
|
||||
"GLB is recommended: textures survive import only when embedded in the file. "
|
||||
"Note that texturing an imported model requires a texture prompt.",
|
||||
inputs=[
|
||||
IO.MultiType.Input(
|
||||
"model_3d",
|
||||
types=[IO.File3DGLB, IO.File3DFBX, IO.File3DOBJ, IO.File3DSTL, IO.File3DAny],
|
||||
tooltip="3D model to import (GLB / FBX / OBJ / STL, up to 150 MB). "
|
||||
"OBJ and STL files carry no embedded textures.",
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
IO.Custom("MODEL_TASK_ID").Output(display_name="model task_id"),
|
||||
],
|
||||
hidden=[
|
||||
IO.Hidden.auth_token_comfy_org,
|
||||
IO.Hidden.api_key_comfy_org,
|
||||
IO.Hidden.unique_id,
|
||||
],
|
||||
is_api_node=True,
|
||||
price_badge=IO.PriceBadge(
|
||||
expr="""{"type":"text","text":"Free"}""",
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
async def execute(cls, model_3d: Types.File3D) -> IO.NodeOutput:
|
||||
file_format = (model_3d.format or "").lstrip(".").lower()
|
||||
if file_format == "gltf":
|
||||
raise ValueError(
|
||||
"GLTF (.gltf) references external files and cannot be imported. Export a single-file GLB instead."
|
||||
)
|
||||
if file_format not in cls.SUPPORTED_FORMATS:
|
||||
raise ValueError(
|
||||
f"Unsupported 3D format '{file_format or 'unknown'}'. "
|
||||
f"Tripo import supports: {', '.join(f.upper() for f in cls.SUPPORTED_FORMATS)}."
|
||||
)
|
||||
size = len(model_3d.get_bytes())
|
||||
if size > 150 * 1024 * 1024:
|
||||
raise ValueError(f"Model file is {size / (1024 * 1024):.1f} MB; Tripo import allows up to 150 MB.")
|
||||
|
||||
url = await upload_3d_model_to_comfyapi(cls, model_3d, file_format)
|
||||
response = await sync_op(
|
||||
cls,
|
||||
endpoint=ApiEndpoint(path="/proxy/tripo/v2/openapi/import", method="POST"),
|
||||
response_model=TripoTaskResponse,
|
||||
data=TripoImportModelRequest(url=url, format=file_format),
|
||||
)
|
||||
if response.code != 0:
|
||||
raise RuntimeError(f"Failed to import model: {response.error}")
|
||||
|
||||
task_id = response.data.task_id
|
||||
response_poll = await poll_op(
|
||||
cls,
|
||||
poll_endpoint=ApiEndpoint(path=f"/proxy/tripo/v2/openapi/task/{task_id}"),
|
||||
response_model=TripoTaskResponse,
|
||||
failed_statuses=[
|
||||
TripoTaskStatus.FAILED,
|
||||
TripoTaskStatus.CANCELLED,
|
||||
TripoTaskStatus.UNKNOWN,
|
||||
TripoTaskStatus.BANNED,
|
||||
TripoTaskStatus.EXPIRED,
|
||||
],
|
||||
status_extractor=lambda x: x.data.status,
|
||||
progress_extractor=lambda x: x.data.progress,
|
||||
estimated_duration=10,
|
||||
)
|
||||
if response_poll.data.status != TripoTaskStatus.SUCCESS:
|
||||
raise RuntimeError(f"Failed to import model: {response_poll}")
|
||||
return IO.NodeOutput(task_id)
|
||||
|
||||
|
||||
def _p1_price_expr(*, geometry_credits: int, textured_credits: int, detailed_credits: int) -> str:
|
||||
return (
|
||||
"("
|
||||
@ -1292,6 +1389,7 @@ class TripoExtension(ComfyExtension):
|
||||
TripoP1TextToModelNode,
|
||||
TripoP1ImageToModelNode,
|
||||
TripoP1MultiviewToModelNode,
|
||||
TripoImportModelNode,
|
||||
TripoTextureNode,
|
||||
TripoRefineNode,
|
||||
TripoRigNode,
|
||||
|
||||
@ -9,6 +9,7 @@ from io import BytesIO
|
||||
from yarl import URL
|
||||
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
from comfy.model_management import processing_interrupted
|
||||
from comfy_api.latest import IO
|
||||
|
||||
@ -35,6 +36,30 @@ def get_auth_header(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
return {}
|
||||
|
||||
|
||||
def get_usage_source(node_cls: type[IO.ComfyNode]) -> str:
|
||||
"""Source of the prompt that triggered this API node.
|
||||
|
||||
Defaults to "comfyui-api" when the submitting client didn't identify itself,
|
||||
i.e. a direct API call to this server.
|
||||
"""
|
||||
return node_cls.hidden.comfy_usage_source or "comfyui-api"
|
||||
|
||||
|
||||
def get_comfy_api_headers(node_cls: type[IO.ComfyNode]) -> dict[str, str]:
|
||||
"""Common headers (auth, deploy environment, usage source) for Comfy API requests.
|
||||
|
||||
Centralizes the shared header set so every Comfy API request sends a consistent
|
||||
set and new shared headers only need to be added in one place. Intended for
|
||||
relative/cloud URLs resolved against ``default_base_url()``; because the result
|
||||
includes auth, callers must not attach it to arbitrary absolute/presigned URLs.
|
||||
"""
|
||||
return {
|
||||
**get_auth_header(node_cls),
|
||||
"Comfy-Env": get_deploy_environment(),
|
||||
"Comfy-Usage-Source": get_usage_source(node_cls),
|
||||
}
|
||||
|
||||
|
||||
def default_base_url() -> str:
|
||||
return getattr(args, "comfy_api_base", "https://api.comfy.org")
|
||||
|
||||
|
||||
@ -19,12 +19,10 @@ from comfy import utils
|
||||
from comfy_api.latest import IO
|
||||
from server import PromptServer
|
||||
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
get_node_id,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
@ -645,8 +643,7 @@ async def _request_base(cfg: _RequestConfig, expect_binary: bool):
|
||||
|
||||
payload_headers = {"Accept": "*/*"} if expect_binary else {"Accept": "application/json"}
|
||||
if not parsed_url.scheme and not parsed_url.netloc: # is URL relative?
|
||||
payload_headers.update(get_auth_header(cfg.node_cls))
|
||||
payload_headers["Comfy-Env"] = get_deploy_environment()
|
||||
payload_headers.update(get_comfy_api_headers(cfg.node_cls))
|
||||
if cfg.endpoint.headers:
|
||||
payload_headers.update(cfg.endpoint.headers)
|
||||
|
||||
|
||||
@ -17,7 +17,7 @@ from folder_paths import get_output_directory
|
||||
from . import request_logger
|
||||
from ._helpers import (
|
||||
default_base_url,
|
||||
get_auth_header,
|
||||
get_comfy_api_headers,
|
||||
is_processing_interrupted,
|
||||
sleep_with_interrupt,
|
||||
to_aiohttp_url,
|
||||
@ -64,7 +64,7 @@ async def download_url_to_bytesio(
|
||||
if cls is None:
|
||||
raise ValueError("For relative 'cloud' paths, the `cls` parameter is required.")
|
||||
url = urljoin(default_base_url().rstrip("/") + "/", url.lstrip("/"))
|
||||
headers = get_auth_header(cls)
|
||||
headers = get_comfy_api_headers(cls)
|
||||
|
||||
while True:
|
||||
attempt += 1
|
||||
|
||||
66
comfy_execution/asset_enrichment.py
Normal file
66
comfy_execution/asset_enrichment.py
Normal file
@ -0,0 +1,66 @@
|
||||
"""Enrich executed-node output entries with asset id."""
|
||||
import logging
|
||||
import os
|
||||
|
||||
|
||||
def enrich_output_with_assets(output_ui: dict) -> dict:
|
||||
"""Register file-type output entries as assets and inject their ``id``.
|
||||
|
||||
Runs at output-processing time, once per produced output, when
|
||||
--enable-assets is set. Returns a new dict; entries without a resolvable
|
||||
on-disk file path are left unchanged. Errors are caught per-entry so a
|
||||
failure never blocks execution or the other entries.
|
||||
"""
|
||||
from comfy.cli_args import args
|
||||
if not args.enable_assets:
|
||||
return output_ui
|
||||
|
||||
import folder_paths
|
||||
from app.assets.services.ingest import register_file_in_place, DependencyMissingError
|
||||
|
||||
enriched = {}
|
||||
for key, entries in output_ui.items():
|
||||
if not isinstance(entries, list):
|
||||
enriched[key] = entries
|
||||
continue
|
||||
new_entries = []
|
||||
for entry in entries:
|
||||
if not isinstance(entry, dict) or "filename" not in entry or "type" not in entry:
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
try:
|
||||
base = folder_paths.get_directory_by_type(entry["type"])
|
||||
if base is None:
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
base_abs = os.path.abspath(base)
|
||||
abs_path = os.path.abspath(os.path.join(base_abs, entry.get("subfolder") or "", entry["filename"]))
|
||||
try:
|
||||
if os.path.commonpath([base_abs, abs_path]) != base_abs:
|
||||
raise ValueError("escapes base")
|
||||
except ValueError:
|
||||
logging.warning("Asset enrichment skipped (path escapes base): %s", entry.get("filename"))
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
if not os.path.isfile(abs_path):
|
||||
new_entries.append(entry)
|
||||
continue
|
||||
|
||||
# Register unconditionally: the file was just produced, and
|
||||
# register_file_in_place re-hashes so an overwritten path can
|
||||
# never carry a stale id.
|
||||
result = register_file_in_place(
|
||||
abs_path=abs_path,
|
||||
name=entry["filename"],
|
||||
tags=[entry["type"]],
|
||||
)
|
||||
|
||||
entry = dict(entry)
|
||||
entry["id"] = result.ref.id
|
||||
except DependencyMissingError:
|
||||
logging.warning("Asset enrichment skipped (blake3 not available): %s", entry.get("filename"))
|
||||
except Exception:
|
||||
logging.warning("Failed to enrich output entry with asset id: %s", entry.get("filename"), exc_info=True)
|
||||
new_entries.append(entry)
|
||||
enriched[key] = new_entries
|
||||
return enriched
|
||||
@ -3,6 +3,7 @@ Job utilities for the /api/jobs endpoint.
|
||||
Provides normalization and helper functions for job status tracking.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from typing import Optional
|
||||
|
||||
from comfy_api.internal import prune_dict
|
||||
@ -19,6 +20,25 @@ class JobStatus:
|
||||
ALL = [PENDING, IN_PROGRESS, COMPLETED, FAILED, CANCELLED]
|
||||
|
||||
|
||||
def validate_job_id(value) -> str:
|
||||
"""Validate a client-supplied job (prompt) id.
|
||||
|
||||
Job ids must be UUIDs in the canonical lowercase hyphenated form. The id
|
||||
is stored and compared verbatim everywhere downstream — history keys,
|
||||
websocket events, and /interrupt matching — so accepting another spelling
|
||||
would silently rewrite the client's id and then miss every exact-match
|
||||
lookup. Rejecting loudly beats that.
|
||||
|
||||
Returns the id unchanged. Raises ValueError when the value is not a
|
||||
string in canonical UUID form.
|
||||
"""
|
||||
if not isinstance(value, str):
|
||||
raise ValueError(f"job id must be a string, got {type(value).__name__}")
|
||||
if str(uuid.UUID(value)) != value:
|
||||
raise ValueError("job id must be a UUID in canonical lowercase hyphenated form")
|
||||
return value
|
||||
|
||||
|
||||
# Media types that can be previewed in the frontend
|
||||
PREVIEWABLE_MEDIA_TYPES = frozenset({'images', 'video', 'audio', '3d', 'text'})
|
||||
|
||||
|
||||
115
comfy_extras/nodes_bernini.py
Normal file
115
comfy_extras/nodes_bernini.py
Normal file
@ -0,0 +1,115 @@
|
||||
import torch
|
||||
from typing_extensions import override
|
||||
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
import node_helpers
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
|
||||
|
||||
def _resize_long_edge(image, max_size, stride=16):
|
||||
"""Resize (preserve aspect) so the long edge <= max_size, then snap each side to `stride`"""
|
||||
h, w = image.shape[1], image.shape[2]
|
||||
scale = min(max_size / max(h, w), 1.0)
|
||||
nh = max(stride, round(h * scale / stride) * stride)
|
||||
nw = max(stride, round(w * scale / stride) * stride)
|
||||
return comfy.utils.common_upscale(image[:, :, :, :3].movedim(-1, 1), nw, nh, "area", "disabled").movedim(1, -1)
|
||||
|
||||
|
||||
class BerniniConditioning(io.ComfyNode):
|
||||
"""Bernini in-context conditioning for a Wan2.2-A14B model.
|
||||
|
||||
Attaches the VAE-encoded source video / reference images to the conditioning
|
||||
source video first, then each reference image
|
||||
|
||||
The task is inferred from which inputs are connected:
|
||||
(nothing) -> t2v (text-to-video)
|
||||
source_video -> v2v (video-to-video)
|
||||
source_video + ref_images -> rv2v (reference-guided video editing)
|
||||
ref_images only -> r2v (reference-to-video)
|
||||
source_video + ref_video -> ads2v (insert image/video into video)
|
||||
|
||||
source_video is the edit base / canvas (resized to width x height).
|
||||
reference_video is moving content to composite in.
|
||||
Streams are ordered source_video, reference_video, then reference_images -> source_id (1, 2, 3, ...).
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="BerniniConditioning",
|
||||
display_name="Bernini Conditioning",
|
||||
category="conditioning/video_models",
|
||||
description="Conditioning node for Bernini in-context video/image conditioning. It can be used for the following tasks: t2v (text-to-video), v2v (video-to-video), rv2v (reference-guided video editing), r2v (reference-to-video), ads2v (insert image/video into video)."
|
||||
"Reference images injected as in-context tokens (r2v, rv2v) are encoded independently at their own native aspect ratio (long edge capped at ref_max_size)",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=832, min=16, max=8192, step=16),
|
||||
io.Int.Input("height", default=480, min=16, max=8192, step=16),
|
||||
io.Int.Input("length", default=81, min=1, max=8192, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("source_video", optional=True, tooltip=(
|
||||
"Source video to edit or restyle (v2v, rv2v). Resized to width/height and trimmed to length.")),
|
||||
io.Image.Input("reference_video", optional=True, tooltip=(
|
||||
"Video to insert into the source video (ads2v).")),
|
||||
io.Autogrow.Input("reference_images", optional=True,
|
||||
template=io.Autogrow.TemplatePrefix(
|
||||
input=io.Image.Input("reference_image", tooltip=(
|
||||
"Reference image injected as an in-context token (r2v, rv2v).")),
|
||||
prefix="reference_image_", min=0, max=8)),
|
||||
io.Int.Input("ref_max_size", default=848, min=16, max=8192, step=16, optional=True, tooltip=(
|
||||
"Max size for the long edge of reference_video and reference_images. Resized with preserved aspect ratio and snapped to 16px.")),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size,
|
||||
source_video=None, reference_video=None, reference_images=None, ref_max_size=848) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8],
|
||||
device=comfy.model_management.intermediate_device())
|
||||
|
||||
# source_video (1), reference_video (2), reference_images (3, 4, ...).
|
||||
context = []
|
||||
if source_video is not None:
|
||||
vid = comfy.utils.common_upscale(source_video[:length, :, :, :3].movedim(-1, 1), width, height, "area", "center").movedim(1, -1)
|
||||
context.append(vae.encode(vid[:, :, :, :3]))
|
||||
|
||||
if reference_video is not None:
|
||||
ref_vid = _resize_long_edge(reference_video[:length], ref_max_size) # moving content, native aspect
|
||||
context.append(vae.encode(ref_vid[:, :, :, :3]))
|
||||
|
||||
# reference_images is an autogrow dict {reference_image_0: IMAGE, ...}; each slot is a
|
||||
# separate stream at its own native aspect (a multi-image batch in one slot -> one stream per frame).
|
||||
if reference_images:
|
||||
for name in sorted(reference_images):
|
||||
imgs = reference_images[name]
|
||||
if imgs is None:
|
||||
continue
|
||||
for i in range(imgs.shape[0]):
|
||||
img = _resize_long_edge(imgs[i:i + 1], ref_max_size) # native aspect per ref
|
||||
context.append(vae.encode(img[:, :, :, :3]))
|
||||
|
||||
if context:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"context_latents": context})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"context_latents": context})
|
||||
|
||||
return io.NodeOutput(positive, negative, {"samples": latent})
|
||||
|
||||
|
||||
class BerniniExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
BerniniConditioning,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> BerniniExtension:
|
||||
return BerniniExtension()
|
||||
@ -36,15 +36,15 @@ class RemoveBackground(IO.ComfyNode):
|
||||
category="image/background removal",
|
||||
description="Generates a foreground mask to remove the background from an image using a background removal model.",
|
||||
inputs=[
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from"),
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask")
|
||||
IO.BackgroundRemoval.Input("bg_removal_model", tooltip="Background removal model used to generate the mask"),
|
||||
IO.Image.Input("image", tooltip="Input image to remove the background from")
|
||||
],
|
||||
outputs=[
|
||||
IO.Mask.Output("mask", tooltip="Generated foreground mask")
|
||||
]
|
||||
)
|
||||
@classmethod
|
||||
def execute(cls, image, bg_removal_model):
|
||||
def execute(cls, bg_removal_model, image):
|
||||
mask = bg_removal_model.encode_image(image)
|
||||
return IO.NodeOutput(mask)
|
||||
|
||||
|
||||
681
comfy_extras/nodes_depth_anything_3.py
Normal file
681
comfy_extras/nodes_depth_anything_3.py
Normal file
@ -0,0 +1,681 @@
|
||||
"""ComfyUI nodes for Depth Anything 3.
|
||||
Model capability matrix:
|
||||
|
||||
Variant head_type has_sky has_conf cam_dec
|
||||
DA3-Small dualdpt False True yes
|
||||
DA3-Base dualdpt False True yes
|
||||
DA3-Mono-Large dpt True False no
|
||||
DA3-Metric-Large dpt True False no (raw output is metres)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
|
||||
import comfy.model_management as mm
|
||||
import comfy.sd
|
||||
import folder_paths
|
||||
from comfy.ldm.colormap import turbo as _turbo
|
||||
from comfy.ldm.depth_anything_3 import preprocess as da3_preprocess
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
|
||||
DA3ModelType = io.Custom("DA3_MODEL")
|
||||
DA3Geometry = io.Custom("DA3_GEOMETRY")
|
||||
DA3PointCloud = io.Custom("DA3_POINT_CLOUD")
|
||||
|
||||
# DA3_GEOMETRY is a dict with these optional keys (absent when the upstream model didn't produce them):
|
||||
#
|
||||
# Per-frame tensors - B = batch size in mono mode; B = S (number of views) in multi-view mode.
|
||||
# "depth": torch.Tensor (B, H, W) -- raw model depth (always present; matches MoGe convention)
|
||||
# "image": torch.Tensor (B, H, W, 3) -- source image in [0, 1], CPU (always present)
|
||||
# "mode": str -- "mono" or "multiview" (always present)
|
||||
# "sky": torch.Tensor (B, H, W) -- sky probability in [0, 1] (Mono/Metric variants only)
|
||||
# "confidence": torch.Tensor (B, H, W) -- raw model confidence output (Small/Base variants only)
|
||||
#
|
||||
# Multi-view only - S = number of views; the leading 1 is the scene dimension from the model.
|
||||
# "extrinsics": torch.Tensor (1, S, 3, 4) -- world-to-camera [R|t] matrices
|
||||
# "intrinsics": torch.Tensor (1, S, 3, 3) -- pixel-space intrinsics
|
||||
#
|
||||
# DA3_POINT_CLOUD is a dict:
|
||||
# "points": torch.Tensor (N, 3) -- 3-D coords in glTF convention (Y-up, Z-back)
|
||||
# "colors": torch.Tensor (N, 3) -- RGB in [0, 1], or None
|
||||
# "confidence": torch.Tensor (N,) -- raw confidence per point, or None
|
||||
|
||||
|
||||
def _da3_unproject(depth: torch.Tensor, K: torch.Tensor) -> torch.Tensor:
|
||||
"""Pixel-space K⁻¹ unprojection: (H,W) depth → (H,W,3) point map in OpenCV space."""
|
||||
H, W = depth.shape
|
||||
u = torch.arange(W, dtype=torch.float32, device=depth.device)
|
||||
v = torch.arange(H, dtype=torch.float32, device=depth.device)
|
||||
u, v = torch.meshgrid(u, v, indexing='xy') # both (H, W)
|
||||
pix = torch.stack([u, v, torch.ones_like(u)], dim=-1) # (H, W, 3)
|
||||
rays = torch.einsum('ij,hwj->hwi', torch.linalg.inv(K.to(depth.device)), pix)
|
||||
return rays * depth.unsqueeze(-1) # (H, W, 3)
|
||||
|
||||
|
||||
def _da3_default_K(H: int, W: int) -> torch.Tensor:
|
||||
"""Fallback ~60° FOV pinhole K for mono-mode DA3 (no intrinsics in geometry)."""
|
||||
fx = fy = float(W) * 0.7
|
||||
return torch.tensor([[fx, 0.0, (W - 1) / 2.0],
|
||||
[0.0, fy, (H - 1) / 2.0],
|
||||
[0.0, 0.0, 1.0]], dtype=torch.float32)
|
||||
|
||||
|
||||
def _da3_get_K(geometry: dict, b: int, H: int, W: int) -> torch.Tensor:
|
||||
"""Return pixel-space K for batch element b, falling back to a default estimate."""
|
||||
if "intrinsics" in geometry:
|
||||
# shape (1, S, 3, 3) - leading scene dimension from the multiview head
|
||||
return geometry["intrinsics"][0, b].float()
|
||||
logging.getLogger("comfy").warning(
|
||||
"DA3_GEOMETRY has no intrinsics (mono-mode model). "
|
||||
"Using a ~60° FOV estimate; 3-D reconstruction may be inaccurate."
|
||||
)
|
||||
return _da3_default_K(H, W)
|
||||
|
||||
|
||||
def _da3_get_extrinsic(geometry: dict, b: int) -> torch.Tensor | None:
|
||||
"""Return the world-to-camera extrinsic for batch element b, or None in mono mode.
|
||||
|
||||
The model outputs (1, S, 3, 4) [R|t] matrices; the fallback identity is (4, 4).
|
||||
_da3_apply_extrinsic handles both shapes via [:3, :3] / [:3, 3] slicing.
|
||||
"""
|
||||
if "extrinsics" not in geometry:
|
||||
return None
|
||||
return geometry["extrinsics"][0, b].float()
|
||||
|
||||
|
||||
def _da3_apply_extrinsic(points_cam: torch.Tensor, E: torch.Tensor) -> torch.Tensor:
|
||||
"""Transform (H,W,3) OpenCV camera-space points to world space."""
|
||||
E = E.to(points_cam.device).float()
|
||||
if not torch.isfinite(E).all():
|
||||
logging.getLogger("comfy").warning(
|
||||
"DA3 extrinsic matrix contains non-finite values (pose estimation may have failed). "
|
||||
"Falling back to camera-space coordinates."
|
||||
)
|
||||
return points_cam
|
||||
H, W, _ = points_cam.shape
|
||||
R = E[:3, :3] # (3, 3) rotation
|
||||
t = E[:3, 3] # (3,) translation
|
||||
R_inv = R.T # rotation inverse = transpose for orthogonal R
|
||||
t_inv = -(R_inv @ t) # (3,)
|
||||
pts = points_cam.reshape(-1, 3) # (N, 3)
|
||||
pts_world = pts @ R_inv.T + t_inv # (N, 3)
|
||||
return pts_world.reshape(H, W, 3)
|
||||
|
||||
|
||||
def _normalize_confidence(conf: torch.Tensor) -> torch.Tensor:
|
||||
"""Map raw confidence to [0, 1] per image."""
|
||||
B = conf.shape[0]
|
||||
out = []
|
||||
for i in range(B):
|
||||
c = conf[i]
|
||||
c_min, c_max = c.min(), c.max()
|
||||
out.append((c - c_min) / (c_max - c_min) if c_max > c_min else torch.ones_like(c))
|
||||
return torch.stack(out, dim=0)
|
||||
|
||||
|
||||
def _da3_build_mask(geometry: dict, b: int, H: int, W: int, confidence_threshold: float, use_sky_mask: bool) -> torch.Tensor:
|
||||
"""Build (H,W) bool keep-mask from sky probability and confidence."""
|
||||
mask = torch.ones(H, W, dtype=torch.bool)
|
||||
if use_sky_mask and "sky" in geometry:
|
||||
mask = mask & (geometry["sky"][b] < 0.5)
|
||||
if "confidence" in geometry and confidence_threshold > 0.0:
|
||||
conf_norm = _normalize_confidence(geometry["confidence"][b:b + 1])[0]
|
||||
mask = mask & (conf_norm >= confidence_threshold)
|
||||
return mask
|
||||
|
||||
|
||||
class LoadDA3Model(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="LoadDA3Model",
|
||||
display_name="Load Depth Anything 3",
|
||||
category="model/loaders",
|
||||
inputs=[
|
||||
io.Combo.Input(
|
||||
"model_name",
|
||||
options=folder_paths.get_filename_list("geometry_estimation"),
|
||||
),
|
||||
io.Combo.Input(
|
||||
"weight_dtype",
|
||||
options=["default", "fp16", "bf16", "fp32"],
|
||||
default="default",
|
||||
),
|
||||
],
|
||||
outputs=[DA3ModelType.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, model_name, weight_dtype) -> io.NodeOutput:
|
||||
model_options = {}
|
||||
if weight_dtype == "fp16":
|
||||
model_options["dtype"] = torch.float16
|
||||
elif weight_dtype == "bf16":
|
||||
model_options["dtype"] = torch.bfloat16
|
||||
elif weight_dtype == "fp32":
|
||||
model_options["dtype"] = torch.float32
|
||||
|
||||
path = folder_paths.get_full_path_or_raise("geometry_estimation", model_name)
|
||||
model = comfy.sd.load_diffusion_model(path, model_options=model_options)
|
||||
return io.NodeOutput(model)
|
||||
|
||||
|
||||
def _run_da3(model_patcher, image: torch.Tensor, process_res: int, method: str = "upper_bound_resize"):
|
||||
"""Run DA3 on (B,H,W,3), returns depth/conf/sky at original resolution (or None)."""
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
|
||||
B, H, W, _ = image.shape
|
||||
mm.load_model_gpu(model_patcher)
|
||||
diffusion = model_patcher.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
depths, confs, skies = [], [], []
|
||||
for i in range(B):
|
||||
single = image[i:i + 1].to(device)
|
||||
x = da3_preprocess.preprocess_image(single, process_res=process_res, method=method)
|
||||
x = x.to(dtype=dtype)
|
||||
with torch.no_grad():
|
||||
out = diffusion(x)
|
||||
|
||||
depth_lr = out["depth"]
|
||||
depth_full = torch.nn.functional.interpolate(
|
||||
depth_lr.unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
depths.append(depth_full)
|
||||
|
||||
if "depth_conf" in out:
|
||||
conf_full = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
confs.append(conf_full)
|
||||
if "sky" in out:
|
||||
sky_full = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
skies.append(sky_full)
|
||||
|
||||
depth = torch.cat(depths, dim=0)
|
||||
confidence = torch.cat(confs, dim=0) if confs else None
|
||||
sky = torch.cat(skies, dim=0) if skies else None
|
||||
return depth, confidence, sky
|
||||
|
||||
|
||||
class DA3Inference(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3Inference",
|
||||
search_aliases=["depth", "geometry", "da3", "depth anything", "monocular", "pointmap", "sky", "3d", "metric depth", "disparity"],
|
||||
display_name="Run Depth Anything 3",
|
||||
category="image/geometry estimation",
|
||||
description="Run Depth Anything 3 on an image. In multi-view mode each image is treated as a separate view of the same scene.",
|
||||
inputs=[
|
||||
DA3ModelType.Input("da3_model"),
|
||||
io.Image.Input("image"),
|
||||
io.Int.Input("resolution", default=504, min=140, max=2520, step=14,
|
||||
tooltip="Resolution the model runs at (longest side, multiple of 14).\n"
|
||||
"Lower = faster / less VRAM.\n"
|
||||
"Higher = more detail.\n"
|
||||
"Output is upsampled back to the original size."),
|
||||
io.Combo.Input("resize_method", options=["upper_bound_resize", "lower_bound_resize"], default="upper_bound_resize",
|
||||
tooltip="upper_bound_resize: scale so the longest side = resolution (caps memory, default).\n"
|
||||
"lower_bound_resize: scale so the shortest side = resolution (preserves more detail on tall/wide images, uses more memory)."),
|
||||
io.DynamicCombo.Input("mode", tooltip="mono: single view image (works with any model variant).\n"
|
||||
"multiview: all images processed together for geometric consistency + camera pose (for Small/Base models only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("mono", []),
|
||||
io.DynamicCombo.Option("multiview", [
|
||||
io.Combo.Input("ref_view_strategy", options=["saddle_balanced", "saddle_sim_range", "first", "middle"], default="saddle_balanced",
|
||||
tooltip="Which view acts as the geometric anchor.\n"
|
||||
"- saddle_balanced: the view most 'average' across all others (best general choice).\n"
|
||||
"- saddle_sim_range: the view most visually distinct from the others.\n"
|
||||
"- first / middle: fixed positional picks."),
|
||||
io.Combo.Input("pose_method", options=["cam_dec", "ray_pose"], default="cam_dec",
|
||||
tooltip="How the camera field-of-view is estimated (for Small/Base models only).\n"
|
||||
"- cam_dec: learned from image features.\n"
|
||||
"- ray_pose: derived geometrically from the model's 3D ray output.\n"
|
||||
"Affects perspective correctness of the 3D output. Try both if results look distorted."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
outputs=[
|
||||
DA3Geometry.Output("da3_geometry", tooltip="Dictionary of non-normalized tensors.\n"
|
||||
"Always has the keys: depth, image, mode.\n"
|
||||
"Optional keys: sky (for Mono/Metric), confidence (for Small/Base), extrinsics + intrinsics (for multi-view)."),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_model, image, resolution, resize_method, mode) -> io.NodeOutput:
|
||||
mode_val = mode["mode"] # "mono" or "multiview"
|
||||
|
||||
if mode_val == "mono":
|
||||
return cls._execute_mono(da3_model, image, resolution, resize_method)
|
||||
|
||||
# Capability checks for multi-view mode.
|
||||
diffusion = da3_model.model.diffusion_model
|
||||
pose_method = mode["pose_method"]
|
||||
ref_view_strategy = mode["ref_view_strategy"]
|
||||
|
||||
has_cam_dec = diffusion.cam_dec is not None
|
||||
has_dualdpt = diffusion.head_type == "dualdpt"
|
||||
|
||||
if not has_cam_dec and not has_dualdpt:
|
||||
raise ValueError(
|
||||
"multi-view mode requires Small or Base model. The loaded model "
|
||||
f"(head_type='{diffusion.head_type}') does not support cross-view "
|
||||
"attention or camera pose estimation. Switch mode to 'mono', or "
|
||||
"load Small or Base model for mult-view."
|
||||
)
|
||||
|
||||
if pose_method == "cam_dec" and not has_cam_dec:
|
||||
raise ValueError(
|
||||
"pose_method='cam_dec' requires a camera decoder, but the loaded "
|
||||
f"model (head_type='{diffusion.head_type}') does not have one. "
|
||||
"Use pose_method='ray_pose' instead."
|
||||
)
|
||||
if pose_method == "ray_pose" and not has_dualdpt:
|
||||
raise ValueError(
|
||||
"pose_method='ray_pose' requires a DualDPT head, but the loaded "
|
||||
f"model has a '{diffusion.head_type}' head. "
|
||||
"Use pose_method='cam_dec' instead."
|
||||
)
|
||||
|
||||
return cls._execute_multiview(
|
||||
da3_model, image, resolution, resize_method,
|
||||
ref_view_strategy, pose_method,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def _execute_mono(cls, model, image, resolution, resize_method) -> io.NodeOutput:
|
||||
depth, confidence, sky = _run_da3(model, image, resolution, method=resize_method)
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "mono",
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
if confidence is not None:
|
||||
geometry["confidence"] = confidence.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
@classmethod
|
||||
def _execute_multiview(cls, model, image, resolution, resize_method, ref_view_strategy, pose_method) -> io.NodeOutput:
|
||||
assert image.ndim == 4 and image.shape[-1] == 3, \
|
||||
f"expected (B,H,W,3) IMAGE; got {tuple(image.shape)}"
|
||||
S, H, W, _ = image.shape
|
||||
|
||||
mm.load_model_gpu(model)
|
||||
diffusion = model.model.diffusion_model
|
||||
device = mm.get_torch_device()
|
||||
dtype = diffusion.dtype if diffusion.dtype is not None else torch.float32
|
||||
|
||||
# All views in a single forward pass: (1, S, 3, H', W').
|
||||
x = image.to(device)
|
||||
x = da3_preprocess.preprocess_image(x, process_res=resolution, method=resize_method)
|
||||
x = x.to(dtype=dtype).unsqueeze(0)
|
||||
|
||||
use_ray_pose = (pose_method == "ray_pose")
|
||||
with torch.no_grad():
|
||||
out = diffusion(x, use_ray_pose=use_ray_pose, ref_view_strategy=ref_view_strategy)
|
||||
|
||||
depth = torch.nn.functional.interpolate(
|
||||
out["depth"].float().unsqueeze(1), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
sky = None
|
||||
if "sky" in out:
|
||||
sky = torch.nn.functional.interpolate(
|
||||
out["sky"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
|
||||
if "extrinsics" in out and "intrinsics" in out:
|
||||
extrinsics = out["extrinsics"].float().cpu()
|
||||
intrinsics = out["intrinsics"].float().cpu()
|
||||
else:
|
||||
extrinsics = torch.eye(4)[None, None].expand(1, S, 4, 4).clone()
|
||||
intrinsics = torch.eye(3)[None, None].expand(1, S, 3, 3).clone()
|
||||
|
||||
geometry: dict = {
|
||||
"depth": depth.contiguous(),
|
||||
"image": image[..., :3].cpu(),
|
||||
"mode": "multiview",
|
||||
"extrinsics": extrinsics.contiguous(),
|
||||
"intrinsics": intrinsics.contiguous(),
|
||||
}
|
||||
if sky is not None:
|
||||
geometry["sky"] = sky.contiguous()
|
||||
if "depth_conf" in out:
|
||||
conf = torch.nn.functional.interpolate(
|
||||
out["depth_conf"].unsqueeze(1).float(), size=(H, W),
|
||||
mode="bilinear", align_corners=False,
|
||||
).squeeze(1).cpu()
|
||||
geometry["confidence"] = conf.contiguous()
|
||||
return io.NodeOutput(geometry)
|
||||
|
||||
|
||||
class DA3Render(io.ComfyNode):
|
||||
"""Render a visualization from a DA3_GEOMETRY packet."""
|
||||
|
||||
_DEPTH_RENDER_INPUTS = [
|
||||
io.Combo.Input("normalization",
|
||||
options=["v2_style", "min_max", "raw"],
|
||||
default="v2_style",
|
||||
tooltip="- v2_style: mean/std normalisation for perceptually balanced results (default).\n"
|
||||
"- min_max: stretches the full depth range to [0, 1] for maximum contrast.\n"
|
||||
"- raw: no scaling,preserves metric units for Metric model."),
|
||||
io.Boolean.Input("apply_sky_clip", default=False,
|
||||
tooltip="Clip sky-region depth to the 99th percentile of foreground depth before normalisation. "
|
||||
"Requires a sky key in the da3_geometry input (for Mono/Metric models only)."),
|
||||
]
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3Render",
|
||||
display_name="Render Depth Anything 3",
|
||||
category="image/geometry estimation",
|
||||
description="Render a depth map, confidence map, or sky mask from Depth Anything 3 geometry data.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.DynamicCombo.Input("output",
|
||||
tooltip="- depth: normalised greyscale depth image.\n"
|
||||
"- depth_colored: depth mapped through the Turbo colormap.\n"
|
||||
"- sky_mask: sky probability in [0, 1] (for Mono/Metric models only).\n"
|
||||
"- confidence: normalised depth confidence (for Small/Base models only).",
|
||||
options=[
|
||||
io.DynamicCombo.Option("depth", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("depth_colored", cls._DEPTH_RENDER_INPUTS),
|
||||
io.DynamicCombo.Option("sky_mask", [
|
||||
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the sky mask."),
|
||||
]),
|
||||
io.DynamicCombo.Option("confidence", [
|
||||
io.Boolean.Input("colored", default=False, tooltip="Apply the Turbo colormap to the confidence map."),
|
||||
]),
|
||||
]),
|
||||
],
|
||||
outputs=[io.Image.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, output) -> io.NodeOutput:
|
||||
output_val = output["output"]
|
||||
|
||||
if output_val in ("depth", "depth_colored"):
|
||||
normalization = output["normalization"]
|
||||
apply_sky_clip = output["apply_sky_clip"]
|
||||
if apply_sky_clip and "sky" not in da3_geometry:
|
||||
raise ValueError(
|
||||
"apply_sky_clip=True requires a sky tensor in the da3_geometry input, but none is present. "
|
||||
"Run with Mono/Metric models or set apply_sky_clip=False."
|
||||
)
|
||||
depth = da3_geometry["depth"]
|
||||
sky = da3_geometry.get("sky")
|
||||
if apply_sky_clip and sky is not None:
|
||||
depth = torch.stack([
|
||||
da3_preprocess.apply_sky_aware_clip(depth[i], sky[i])
|
||||
for i in range(depth.shape[0])
|
||||
], dim=0)
|
||||
grey = cls._depth_to_image(depth, sky, normalization) # (B,H,W,3) greyscale
|
||||
result = _turbo(grey[..., 0]) if output_val == "depth_colored" else grey
|
||||
|
||||
elif output_val == "sky_mask":
|
||||
if "sky" not in da3_geometry:
|
||||
raise ValueError("geometry has no sky output; run with Mono/Metric models.")
|
||||
sky = da3_geometry["sky"]
|
||||
if output["colored"]:
|
||||
result = _turbo(sky)
|
||||
else:
|
||||
result = sky.unsqueeze(-1).expand(*sky.shape, 3).contiguous()
|
||||
|
||||
elif output_val == "confidence":
|
||||
if "confidence" not in da3_geometry:
|
||||
raise ValueError("da3_geometry has no confidence output; run with Small/Base models.")
|
||||
conf = _normalize_confidence(da3_geometry["confidence"])
|
||||
if output["colored"]:
|
||||
result = _turbo(conf)
|
||||
else:
|
||||
result = conf.unsqueeze(-1).expand(*conf.shape, 3).contiguous()
|
||||
|
||||
else:
|
||||
raise ValueError(f"Unknown output mode: {output_val}")
|
||||
|
||||
return io.NodeOutput(result.float())
|
||||
|
||||
@staticmethod
|
||||
def _depth_to_image(depth: torch.Tensor, sky_for_norm: torch.Tensor | None, normalization: str) -> torch.Tensor:
|
||||
"""Normalise depth and pack as an (B,H,W,3) image tensor."""
|
||||
|
||||
N = depth.shape[0]
|
||||
if normalization == "v2_style":
|
||||
norm = torch.stack([
|
||||
da3_preprocess.normalize_depth_v2_style(
|
||||
depth[i], sky_for_norm[i] if sky_for_norm is not None else None)
|
||||
for i in range(N)
|
||||
], dim=0)
|
||||
elif normalization == "min_max":
|
||||
norm = da3_preprocess.normalize_depth_min_max(depth)
|
||||
else:
|
||||
norm = depth
|
||||
|
||||
out = norm.unsqueeze(-1).repeat(1, 1, 1, 3)
|
||||
if normalization != "raw":
|
||||
out = out.clamp(0.0, 1.0)
|
||||
return out.contiguous()
|
||||
|
||||
|
||||
class DA3GeometryToMesh(io.ComfyNode):
|
||||
"""Convert a DA3_GEOMETRY packet into a Types.MESH by unprojecting depth and triangulating."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3GeometryToMesh",
|
||||
search_aliases=["da3", "depth anything", "mesh", "geometry", "3d", "triangulate"],
|
||||
display_name="Convert DA3 Geometry to Mesh",
|
||||
category="image/geometry estimation",
|
||||
description="Convert a depth map into a triangulated 3D mesh.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert. Per-image vertex counts differ so batches cannot be stacked."),
|
||||
io.Int.Input("decimation", default=1, min=1, max=8, tooltip="Vertex stride. 1 = full resolution, 2 = half, etc."),
|
||||
io.Float.Input("discontinuity_threshold", default=0.04, min=0.0, max=1.0, step=0.01, tooltip="Drop triangles whose 3x3 depth span exceeds this fraction. 0 = off."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all, 1 = keep only the single most confident pixel). "
|
||||
"Used when the geometry has a confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True, tooltip="Exclude sky-probability pixels (sky >= 0.5) from the mesh. Used when the geometry has a sky map (Mono/Metric models)."),
|
||||
io.Boolean.Input("texture", default=True, tooltip="Use the source image as a base color texture."),
|
||||
],
|
||||
outputs=[io.Mesh.Output()],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, decimation, discontinuity_threshold, confidence_threshold, use_sky_mask, texture) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index] # (H, W)
|
||||
H, W = depth.shape
|
||||
|
||||
# NaN/inf depth would propagate silently through unproject and produce an
|
||||
# empty mesh; replace them with 0 here so those pixels are later excluded
|
||||
# by the isfinite check inside triangulate_grid_mesh.
|
||||
depth = depth.clone()
|
||||
n_bad = (~torch.isfinite(depth)).sum().item()
|
||||
if n_bad:
|
||||
logging.getLogger("comfy").warning(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] has {n_bad} non-finite pixels "
|
||||
f"({100*n_bad/(H*W):.1f}%) - zeroed before unproject."
|
||||
)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
logging.getLogger("comfy").debug(
|
||||
f"DA3GeometryToMesh: depth[{batch_index}] range "
|
||||
f"[{depth.min():.4g}, {depth.max():.4g}], mean={depth.mean():.4g}"
|
||||
)
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
points = _da3_unproject(depth, K) # (H, W, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Mask invalid pixels by setting them to inf so triangulate_grid_mesh skips them.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
# Also exclude pixels where depth was invalid.
|
||||
mask = mask & (depth_all[batch_index] > 0) & torch.isfinite(depth_all[batch_index])
|
||||
points = points.clone()
|
||||
points[~mask] = float('inf')
|
||||
|
||||
verts, faces, uvs = triangulate_grid_mesh(
|
||||
points,
|
||||
decimation=decimation,
|
||||
discontinuity_threshold=discontinuity_threshold,
|
||||
depth=depth,
|
||||
)
|
||||
if verts.shape[0] == 0 or faces.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"DA3GeometryToMesh produced an empty mesh. "
|
||||
"Try raising discontinuity_threshold, lowering confidence_threshold, "
|
||||
"or disabling use_sky_mask."
|
||||
)
|
||||
|
||||
# OpenCV (X right, Y down, Z forward) → glTF (X right, Y up, Z back).
|
||||
# Same transform as MoGePointMapToMesh perspective branch.
|
||||
verts = verts * torch.tensor([1.0, -1.0, -1.0], dtype=verts.dtype)
|
||||
faces = faces[:, [0, 2, 1]].contiguous()
|
||||
|
||||
tex = da3_geometry["image"][batch_index:batch_index + 1] if texture else None
|
||||
mesh = Types.MESH(
|
||||
vertices=verts.unsqueeze(0),
|
||||
faces=faces.unsqueeze(0),
|
||||
uvs=uvs.unsqueeze(0),
|
||||
texture=tex,
|
||||
)
|
||||
return io.NodeOutput(mesh)
|
||||
|
||||
|
||||
class DA3GeometryToPointCloud(io.ComfyNode):
|
||||
"""Unproject a DA3_GEOMETRY depth map into a filtered DA3_POINT_CLOUD."""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="DA3GeometryToPointCloud",
|
||||
search_aliases=["da3", "depth anything", "point cloud", "pointcloud", "3d", "geometry"],
|
||||
display_name="Convert DA3 Geometry to Point Cloud",
|
||||
category="image/geometry estimation",
|
||||
description="Convert a depth map into a 3D point cloud.",
|
||||
inputs=[
|
||||
DA3Geometry.Input("da3_geometry"),
|
||||
io.Int.Input("batch_index", default=0, min=0, max=4096, tooltip="Which image of a batch to convert."),
|
||||
io.Float.Input("confidence_threshold", default=0.1, min=0.0, max=1.0, step=0.01,
|
||||
tooltip="Exclude pixels whose per-image normalised confidence is below this value (0 = keep all). Used when the geometry has a confidence map (Small/Base models)."),
|
||||
io.Boolean.Input("use_sky_mask", default=True,
|
||||
tooltip="Exclude sky-probability pixels (sky >= 0.5). Used when the geometry has a sky map (Mono/Metric models)."),
|
||||
io.Int.Input("downsample", default=1, min=1, max=16,
|
||||
tooltip="Take every Nth pixel (1 = full resolution). Higher values give fewer points and faster processing."),
|
||||
],
|
||||
# TODO: add a proper PointCloud output type
|
||||
outputs=[DA3PointCloud.Output(display_name="point_cloud")],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, da3_geometry, batch_index, confidence_threshold, use_sky_mask, downsample) -> io.NodeOutput:
|
||||
depth_all = da3_geometry["depth"] # (B, H, W)
|
||||
B = depth_all.shape[0]
|
||||
if batch_index >= B:
|
||||
raise ValueError(f"batch_index {batch_index} is out of range; DA3_GEOMETRY has batch size {B}.")
|
||||
|
||||
depth = depth_all[batch_index].clone() # (H, W)
|
||||
depth[~torch.isfinite(depth)] = 0.0
|
||||
H, W = depth.shape
|
||||
|
||||
K = _da3_get_K(da3_geometry, batch_index, H, W)
|
||||
|
||||
if downsample > 1:
|
||||
depth = depth[::downsample, ::downsample].contiguous()
|
||||
# Scale intrinsics to the downsampled grid.
|
||||
K = K.clone()
|
||||
K[0, :] /= downsample
|
||||
K[1, :] /= downsample
|
||||
|
||||
H_ds, W_ds = depth.shape
|
||||
points = _da3_unproject(depth, K) # (H_ds, W_ds, 3) in OpenCV camera space
|
||||
|
||||
# Apply world-to-camera inverse so multi-view frames share a common world frame.
|
||||
E = _da3_get_extrinsic(da3_geometry, batch_index)
|
||||
if E is not None:
|
||||
points = _da3_apply_extrinsic(points, E)
|
||||
|
||||
# Rebuild mask at downsampled resolution.
|
||||
mask = _da3_build_mask(da3_geometry, batch_index, H, W, confidence_threshold, use_sky_mask)
|
||||
if downsample > 1:
|
||||
mask = mask[::downsample, ::downsample]
|
||||
|
||||
mask = mask & torch.isfinite(depth)
|
||||
|
||||
# OpenCV → glTF: flip Y and Z.
|
||||
points_gltf = points.clone()
|
||||
points_gltf[..., 1] *= -1.0
|
||||
points_gltf[..., 2] *= -1.0
|
||||
|
||||
pts_flat = points_gltf.reshape(-1, 3)[mask.reshape(-1)]
|
||||
|
||||
colors_flat = None
|
||||
if "image" in da3_geometry:
|
||||
img = da3_geometry["image"][batch_index] # (H, W, 3)
|
||||
if downsample > 1:
|
||||
img = img[::downsample, ::downsample]
|
||||
colors_flat = img.reshape(-1, 3)[mask.reshape(-1)]
|
||||
|
||||
conf_flat = None
|
||||
if "confidence" in da3_geometry:
|
||||
conf = da3_geometry["confidence"][batch_index] # (H, W)
|
||||
if downsample > 1:
|
||||
conf = conf[::downsample, ::downsample]
|
||||
conf_flat = conf.reshape(-1)[mask.reshape(-1)]
|
||||
|
||||
if pts_flat.shape[0] == 0:
|
||||
raise ValueError(
|
||||
"DA3GeometryToPointCloud produced zero points after filtering. "
|
||||
"Try lowering confidence_threshold or disabling use_sky_mask."
|
||||
)
|
||||
|
||||
return io.NodeOutput({
|
||||
"points": pts_flat,
|
||||
"colors": colors_flat,
|
||||
"confidence": conf_flat,
|
||||
})
|
||||
|
||||
|
||||
class DA3Extension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
LoadDA3Model,
|
||||
DA3Inference,
|
||||
DA3Render,
|
||||
DA3GeometryToMesh,
|
||||
# DA3GeometryToPointCloud, # Keep this commented out for now until we have a proper PointCloud output type
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> DA3Extension:
|
||||
return DA3Extension()
|
||||
@ -245,6 +245,11 @@ class KV_Attn_Input:
|
||||
cache_key = "{}_{}".format(extra_options["block_type"], extra_options["block_index"])
|
||||
if cache_key in self.cache:
|
||||
kk, vv = self.cache[cache_key]
|
||||
|
||||
# Fix batch size changing.
|
||||
kk = comfy.utils.repeat_to_batch_size(kk, k.shape[0])
|
||||
vv = comfy.utils.repeat_to_batch_size(vv, v.shape[0])
|
||||
|
||||
self.set_cache = False
|
||||
return {"q": q, "k": torch.cat((k, kk), dim=2), "v": torch.cat((v, vv), dim=2)}
|
||||
|
||||
|
||||
@ -8,6 +8,7 @@ import folder_paths
|
||||
from comfy_api.latest import ComfyExtension, Types, io
|
||||
from typing_extensions import override
|
||||
|
||||
from comfy.ldm.colormap import turbo as _turbo
|
||||
from comfy.ldm.moge.model import MoGeModel
|
||||
from comfy.ldm.moge.geometry import triangulate_grid_mesh
|
||||
from comfy.ldm.moge.panorama import get_panorama_cameras, split_panorama_image, merge_panorama_depth, spherical_uv_to_directions, _uv_grid
|
||||
@ -27,19 +28,6 @@ MoGeGeometry = io.Custom("MOGE_GEOMETRY")
|
||||
# "image": torch.Tensor (B, H, W, 3) in [0, 1], CPU (always present)
|
||||
|
||||
|
||||
def _turbo(x: torch.Tensor) -> torch.Tensor:
|
||||
"""Anton Mikhailov polynomial approximation of the turbo colormap."""
|
||||
x = x.clamp(0.0, 1.0)
|
||||
x2 = x * x
|
||||
x3 = x2 * x
|
||||
x4 = x2 * x2
|
||||
x5 = x4 * x
|
||||
r = 0.13572138 + 4.61539260*x - 42.66032258*x2 + 132.13108234*x3 - 152.94239396*x4 + 59.28637943*x5
|
||||
g = 0.09140261 + 2.19418839*x + 4.84296658*x2 - 14.18503333*x3 + 4.27729857*x4 + 2.82956604*x5
|
||||
b = 0.10667330 + 12.64194608*x - 60.58204836*x2 + 110.36276771*x3 - 89.90310912*x4 + 27.34824973*x5
|
||||
return torch.stack([r, g, b], dim=-1).clamp(0.0, 1.0)
|
||||
|
||||
|
||||
def _normals_from_points(points: torch.Tensor) -> torch.Tensor:
|
||||
"""Camera-space surface normals from a (B, H, W, 3) point map (v1 fallback)."""
|
||||
finite = torch.isfinite(points).all(dim=-1)
|
||||
|
||||
325
comfy_extras/nodes_scail.py
Normal file
325
comfy_extras/nodes_scail.py
Normal file
@ -0,0 +1,325 @@
|
||||
"""SCAIL / SCAIL-2 nodes: the WanSCAILToVideo conditioning node and the SAM3
|
||||
preprocessing that turns video tracks into the bundle the SCAIL-2 model consumes."""
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
|
||||
import nodes
|
||||
import node_helpers
|
||||
import comfy.model_management
|
||||
import comfy.utils
|
||||
from comfy_api.latest import ComfyExtension, io
|
||||
from comfy.ldm.sam3.tracker import unpack_masks
|
||||
|
||||
SAM3TrackData = io.Custom("SAM3_TRACK_DATA")
|
||||
|
||||
|
||||
# Model was trained on these exact colors; deviating degrades multi-identity quality.
|
||||
DEFAULT_PALETTE = [
|
||||
(0.0, 0.0, 1.0), # Blue
|
||||
(1.0, 0.0, 0.0), # Red
|
||||
(0.0, 1.0, 0.0), # Green
|
||||
(1.0, 0.0, 1.0), # Magenta
|
||||
(0.0, 1.0, 1.0), # Cyan
|
||||
(1.0, 1.0, 0.0), # Yellow
|
||||
]
|
||||
|
||||
|
||||
def _unpack(track_data):
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
return None
|
||||
return unpack_masks(packed)
|
||||
|
||||
|
||||
def _first_frame_cx_area(masks_bool):
|
||||
first = masks_bool[0].float()
|
||||
H, W = first.shape[-2], first.shape[-1]
|
||||
n_pixels = H * W
|
||||
grid_x = torch.arange(W, device=first.device, dtype=first.dtype).view(1, W)
|
||||
area = first.sum(dim=(-1, -2)).clamp_(min=1)
|
||||
cx = (first * grid_x).sum(dim=(-1, -2)) / area
|
||||
return (cx / W).tolist(), (area / n_pixels).tolist()
|
||||
|
||||
|
||||
def _subset_track_data(track_data, obj_indices):
|
||||
out = dict(track_data)
|
||||
packed = track_data["packed_masks"]
|
||||
if packed is None or not obj_indices:
|
||||
out["packed_masks"] = None
|
||||
if "scores" in out:
|
||||
out["scores"] = []
|
||||
return out
|
||||
out["packed_masks"] = packed[:, obj_indices].contiguous()
|
||||
scores = track_data.get("scores")
|
||||
if scores is not None:
|
||||
out["scores"] = [scores[i] for i in obj_indices if i < len(scores)]
|
||||
return out
|
||||
|
||||
|
||||
def _render_colored_masks(track_data, background="black"):
|
||||
packed = track_data["packed_masks"]
|
||||
H, W = track_data["orig_size"]
|
||||
device = comfy.model_management.intermediate_device()
|
||||
dtype = comfy.model_management.intermediate_dtype()
|
||||
bg_rgb = (1.0, 1.0, 1.0) if background.startswith("white") else (0.0, 0.0, 0.0)
|
||||
if packed is None or packed.shape[1] == 0:
|
||||
T = track_data.get("n_frames", 1) if packed is None else packed.shape[0]
|
||||
out = torch.empty(T, H, W, 3, device=device, dtype=dtype)
|
||||
out[..., 0], out[..., 1], out[..., 2] = bg_rgb[0], bg_rgb[1], bg_rgb[2]
|
||||
return out
|
||||
T, N_obj = packed.shape[0], packed.shape[1]
|
||||
colors = torch.tensor(
|
||||
[DEFAULT_PALETTE[i % len(DEFAULT_PALETTE)] for i in range(N_obj)],
|
||||
device=device, dtype=dtype,
|
||||
)
|
||||
masks_full = unpack_masks(packed.to(device)).float()
|
||||
Hm, Wm = masks_full.shape[-2], masks_full.shape[-1]
|
||||
masks_full = F.interpolate(
|
||||
masks_full.view(T * N_obj, 1, Hm, Wm), size=(H, W), mode="nearest"
|
||||
).view(T, N_obj, H, W) > 0.5
|
||||
any_mask = masks_full.any(dim=1)
|
||||
obj_idx_map = masks_full.to(torch.uint8).argmax(dim=1)
|
||||
color_overlay = colors[obj_idx_map]
|
||||
bg_tensor = torch.tensor(bg_rgb, device=device, dtype=color_overlay.dtype).view(1, 1, 1, 3)
|
||||
return torch.where(any_mask.unsqueeze(-1), color_overlay, bg_tensor.expand_as(color_overlay))
|
||||
|
||||
|
||||
def _extract_mask_to_28ch(rgb_video):
|
||||
"""Colored RGB mask (T, H, W, 3) in [0, 1] -> SCAIL-2 28-channel binary latent
|
||||
(1, T_lat, 28, H_lat, W_lat). 7 per-color binary channels (white/r/g/b/y/m/c)
|
||||
threshold-extracted at 225/255, 8x spatial downsample, 4-frame temporal stacking."""
|
||||
T, H, W, _ = rgb_video.shape
|
||||
_ON_THRESH = 225.0 / 255.0
|
||||
mask = rgb_video.movedim(-1, 1).float()
|
||||
R = (mask[:, 0:1] > _ON_THRESH).float()
|
||||
G = (mask[:, 1:2] > _ON_THRESH).float()
|
||||
B = (mask[:, 2:3] > _ON_THRESH).float()
|
||||
nR, nG, nB = 1 - R, 1 - G, 1 - B
|
||||
binary_7ch = torch.cat([
|
||||
R * G * B, # white
|
||||
R * nG * nB, # red
|
||||
nR * G * nB, # green
|
||||
nR * nG * B, # blue
|
||||
R * G * nB, # yellow
|
||||
R * nG * B, # magenta
|
||||
nR * G * B, # cyan
|
||||
], dim=1)
|
||||
H_lat, W_lat = H, W
|
||||
for _ in range(3):
|
||||
H_lat = (H_lat + 1) // 2
|
||||
W_lat = (W_lat + 1) // 2
|
||||
binary_7ch = torch.nn.functional.interpolate(binary_7ch, size=(H_lat, W_lat), mode='area')
|
||||
T_latent = (T - 1) // 4 + 1
|
||||
padded = torch.cat([binary_7ch[:1].repeat(4, 1, 1, 1), binary_7ch[1:]], dim=0)
|
||||
out = padded.view(T_latent, 28, H_lat, W_lat)
|
||||
return out.unsqueeze(0)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSCAILToVideo",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||
io.Image.Input("pose_video_mask", optional=True, tooltip="SCAIL-2 only. Colored per-identity SAM3 mask video at the same resolution as pose_video."),
|
||||
io.Boolean.Input("replacement_mode", default=False, optional=True, tooltip="SCAIL-2 only. False = Animation Mode (pose_video_mask should have black background). True = Replacement Mode (pose_video_mask should have white background)."),
|
||||
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step of the pose conditioning."),
|
||||
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step of the pose conditioning."),
|
||||
io.Image.Input("reference_image", optional=True, tooltip="Reference image, for multiple references composite all on single image."),
|
||||
io.Image.Input("reference_image_mask", optional=True, tooltip="SCAIL-2 only. Colored reference mask at the same resolution as reference_image."),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True, tooltip="CLIP vision features for conditioning. Model is trained with stretch resize to aspect ratio."),
|
||||
io.Int.Input("video_frame_offset", default=0, min=0, max=nodes.MAX_RESOLUTION, step=1, tooltip="Cumulative output frame this chunk begins at. Wire from the previous chunk's video_frame_offset output."),
|
||||
io.Int.Input("previous_frame_count", default=5, min=1, max=nodes.MAX_RESOLUTION, step=4, tooltip="Tail frames of previous_frames to anchor. SCAIL-2 trained at 5 (81-frame chunks, 76-frame step)."),
|
||||
io.Image.Input("previous_frames", optional=True, tooltip="SCAIL-2 only. Full decoded output of the previous chunk. Only the last previous_frame_count are used as the extension anchor."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||
io.Int.Output(display_name="video_frame_offset", tooltip="Adjusted offset + length. Wire into the next chunk."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end,
|
||||
video_frame_offset, previous_frame_count, replacement_mode=False, reference_image=None, clip_vision_output=None, pose_video=None,
|
||||
pose_video_mask=None, reference_image_mask=None, previous_frames=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
noise_mask = None
|
||||
|
||||
ref_mask_flag = not replacement_mode
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_flag": ref_mask_flag})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_flag": ref_mask_flag})
|
||||
|
||||
prev_trimmed = None
|
||||
if previous_frames is not None and previous_frames.shape[0] > 0:
|
||||
prev_trimmed = previous_frames[-previous_frame_count:]
|
||||
video_frame_offset -= prev_trimmed.shape[0]
|
||||
video_frame_offset = max(0, video_frame_offset)
|
||||
|
||||
ref_latent = None
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
# Replacement Mode: composite ref on black bg using reference_image_mask as alpha matte
|
||||
if replacement_mode and reference_image_mask is not None:
|
||||
rm = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "nearest-exact", "center").movedim(1, -1)
|
||||
is_char = (rm[..., :3].max(dim=-1, keepdim=True).values > 0.1).to(reference_image.dtype)
|
||||
reference_image = reference_image * is_char
|
||||
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [ref_latent]}, append=True)
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
if pose_video.shape[0] <= video_frame_offset:
|
||||
pose_video = None
|
||||
else:
|
||||
pose_video = pose_video[video_frame_offset:]
|
||||
if pose_video_mask is not None:
|
||||
if pose_video_mask.shape[0] <= video_frame_offset:
|
||||
pose_video_mask = None
|
||||
else:
|
||||
pose_video_mask = pose_video_mask[video_frame_offset:]
|
||||
|
||||
# Truncate pose+mask jointly to the shorter of the two, capped at length.
|
||||
ts = [v.shape[0] for v in (pose_video, pose_video_mask) if v is not None]
|
||||
if ts:
|
||||
T_kept = ((min(min(ts), length) - 1) // 4) * 4 + 1
|
||||
if pose_video is not None:
|
||||
pose_video = pose_video[:T_kept]
|
||||
if pose_video_mask is not None:
|
||||
pose_video_mask = pose_video_mask[:T_kept]
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
|
||||
if pose_video_mask is not None:
|
||||
mask_video_hw = comfy.utils.common_upscale(pose_video_mask[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
driving_mask_28ch = _extract_mask_to_28ch(mask_video_hw)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"driving_mask_28ch": driving_mask_28ch})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"driving_mask_28ch": driving_mask_28ch})
|
||||
|
||||
if reference_image_mask is not None:
|
||||
ref_mask_hw = comfy.utils.common_upscale(reference_image_mask[:1].movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
ref_mask_1f = _extract_mask_to_28ch(ref_mask_hw)
|
||||
zeros = torch.zeros((1, latent.shape[2], 28, ref_mask_1f.shape[-2], ref_mask_1f.shape[-1]), device=ref_mask_1f.device, dtype=ref_mask_1f.dtype)
|
||||
ref_mask_28ch = torch.cat([ref_mask_1f, zeros], dim=1)
|
||||
positive = node_helpers.conditioning_set_values(positive, {"ref_mask_28ch": ref_mask_28ch})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"ref_mask_28ch": ref_mask_28ch})
|
||||
|
||||
if prev_trimmed is not None:
|
||||
pf = comfy.utils.common_upscale(prev_trimmed.movedim(-1, 1), width, height, "bicubic", "center").movedim(1, -1)
|
||||
prev_latent = vae.encode(pf[:, :, :, :3])
|
||||
prev_latent_frames = min(prev_latent.shape[2], latent.shape[2])
|
||||
latent[:, :, :prev_latent_frames] = prev_latent[:, :, :prev_latent_frames].to(latent.dtype)
|
||||
noise_mask = torch.ones((1, 1, latent.shape[2], latent.shape[-2], latent.shape[-1]), device=latent.device, dtype=latent.dtype)
|
||||
noise_mask[:, :, :prev_latent_frames] = 0.0
|
||||
|
||||
out_latent = {"samples": latent}
|
||||
if noise_mask is not None:
|
||||
out_latent["noise_mask"] = noise_mask
|
||||
return io.NodeOutput(positive, negative, out_latent, video_frame_offset + length)
|
||||
|
||||
|
||||
class SCAIL2ColoredMask(io.ComfyNode):
|
||||
"""Render SAM3 tracks for the driving pose video and (optionally) the reference
|
||||
image into the two colored masks WanSCAILToVideo consumes. Shared `sort_by`
|
||||
across both outputs guarantees identity K maps to the same color on both
|
||||
sides, for multi-person workflow consistency.
|
||||
reference_image_mask is always rendered black-bg (model convention)
|
||||
pose_video_mask bg follows replacement_mode: black = Animation Mode, white = Replacement Mode
|
||||
"""
|
||||
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="SCAIL2ColoredMask",
|
||||
display_name="Create SCAIL-2 Colored Mask",
|
||||
category="conditioning/video_models/scail",
|
||||
inputs=[
|
||||
SAM3TrackData.Input("driving_track_data", tooltip="SAM3 track of the driving pose video. Will be rendered into the pose_video_mask output."),
|
||||
SAM3TrackData.Input("ref_track_data", optional=True,
|
||||
tooltip="SAM3 track of the reference image."),
|
||||
io.String.Input("object_indices", default="",
|
||||
tooltip="Comma-separated list of person indices to include (e.g. '0,2,3'). Applied to both reference and pose video masks. Empty = all."),
|
||||
io.Combo.Input("sort_by", options=["none", "left_to_right", "area"], default="left_to_right",
|
||||
tooltip="Order in which palette colors are assigned to the tracked objects (applied to both reference and pose video so each identity keeps the same color). left_to_right = leftmost object (by first-frame centroid) gets the first color; area = biggest object (by first-frame mask area) gets the first color; none = keep SAM3's order."),
|
||||
io.Boolean.Input("replacement_mode", default=False,
|
||||
tooltip="False = Animation Mode (pose_video_mask has black background, reference_image_mask has white background). "
|
||||
"True = Replacement Mode (pose_video_mask has white background, reference_image_mask has black background)."),
|
||||
],
|
||||
outputs=[
|
||||
io.Image.Output("pose_video_mask"),
|
||||
io.Image.Output("reference_image_mask"),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, driving_track_data, object_indices, sort_by, replacement_mode, ref_track_data=None):
|
||||
def _prep(td):
|
||||
masks_bool = _unpack(td)
|
||||
if sort_by != "none" and masks_bool is not None:
|
||||
cx, area = _first_frame_cx_area(masks_bool)
|
||||
if sort_by == "left_to_right":
|
||||
order = sorted(range(len(cx)), key=lambda i: cx[i])
|
||||
else: # "area"
|
||||
order = sorted(range(len(area)), key=lambda i: -area[i])
|
||||
td = _subset_track_data(td, order)
|
||||
if object_indices.strip():
|
||||
indices = [int(i.strip()) for i in object_indices.split(",") if i.strip().isdigit()]
|
||||
packed = td.get("packed_masks")
|
||||
n_obj = packed.shape[1] if packed is not None else 0
|
||||
indices = [i for i in indices if 0 <= i < n_obj]
|
||||
td = _subset_track_data(td, indices)
|
||||
return td
|
||||
|
||||
drv = _prep(driving_track_data)
|
||||
# Animation: driving=black, ref=white. Replacement: driving=white, ref=black.
|
||||
mask_video = _render_colored_masks(drv, "white" if replacement_mode else "black")
|
||||
ref_bg = "black" if replacement_mode else "white"
|
||||
|
||||
if ref_track_data is not None:
|
||||
ref = _prep(ref_track_data)
|
||||
reference_image_mask = _render_colored_masks(ref, ref_bg)
|
||||
else:
|
||||
H, W = drv["orig_size"]
|
||||
fill_value = 1.0 if ref_bg == "white" else 0.0
|
||||
reference_image_mask = torch.full((1, H, W, 3), fill_value, device=comfy.model_management.intermediate_device(), dtype=comfy.model_management.intermediate_dtype())
|
||||
|
||||
return io.NodeOutput(mask_video, reference_image_mask)
|
||||
|
||||
|
||||
class SCAILExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
return [
|
||||
WanSCAILToVideo,
|
||||
SCAIL2ColoredMask,
|
||||
]
|
||||
|
||||
|
||||
async def comfy_entrypoint() -> SCAILExtension:
|
||||
return SCAILExtension()
|
||||
@ -15,6 +15,7 @@ import comfy.sampler_helpers
|
||||
import comfy.sd
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy.conds import CONDRegular, CONDList
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import comfy_extras.nodes_custom_sampler
|
||||
import folder_paths
|
||||
@ -120,6 +121,11 @@ def process_cond_list(d, prefix=""):
|
||||
process_cond_list(v, f"{prefix}.{k}")
|
||||
elif isinstance(v, torch.Tensor):
|
||||
d[k] = v.clone()
|
||||
elif isinstance(v, CONDList):
|
||||
v.cond = [t.detach() if isinstance(t, torch.Tensor) else t for t in v.cond]
|
||||
elif isinstance(v, CONDRegular):
|
||||
if isinstance(v.cond, torch.Tensor):
|
||||
v.cond = v.cond.detach()
|
||||
elif isinstance(v, (list, tuple)):
|
||||
for index, item in enumerate(v):
|
||||
process_cond_list(item, f"{prefix}.{k}.{index}")
|
||||
@ -1143,45 +1149,45 @@ class TrainLoraNode(io.ComfyNode):
|
||||
# Process conditioning
|
||||
positive = _process_conditioning(positive)
|
||||
|
||||
# Setup model and dtype
|
||||
mp = model.clone()
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
with torch.inference_mode(False):
|
||||
# Setup model and dtype
|
||||
mp = model.clone(force_deepcopy=True)
|
||||
use_grad_scaler = False
|
||||
lora_dtype = node_helpers.string_to_torch_dtype(lora_dtype)
|
||||
if training_dtype != "none":
|
||||
dtype = node_helpers.string_to_torch_dtype(training_dtype)
|
||||
mp.set_model_compute_dtype(dtype)
|
||||
else:
|
||||
# Detect model's native dtype for autocast
|
||||
model_dtype = mp.model.get_dtype()
|
||||
if model_dtype == torch.float16:
|
||||
dtype = torch.float16
|
||||
# GradScaler only supports float16 gradients, not bfloat16.
|
||||
# Only enable it when lora params will also be in float16.
|
||||
if lora_dtype != torch.bfloat16:
|
||||
use_grad_scaler = True
|
||||
# Warn about fp16 accumulation instability during training
|
||||
if PerformanceFeature.Fp16Accumulation in args.fast:
|
||||
logging.warning(
|
||||
"WARNING: FP16 model detected with fp16_accumulation enabled. "
|
||||
"This combination can be numerically unstable during training and may cause NaN values. "
|
||||
"Suggested fixes: 1) Set training_dtype to 'bf16', or 2) Disable fp16_accumulation (remove from --fast flags)."
|
||||
)
|
||||
else:
|
||||
# For fp8, bf16, or other dtypes, use bf16 autocast
|
||||
dtype = torch.bfloat16
|
||||
|
||||
# Prepare latents and compute counts
|
||||
latents_dtype = dtype if dtype not in (None,) else torch.bfloat16
|
||||
latents, num_images, multi_res = _prepare_latents_and_count(
|
||||
latents, latents_dtype, bucket_mode
|
||||
)
|
||||
|
||||
# Validate and expand conditioning
|
||||
positive = _validate_and_expand_conditioning(positive, num_images, bucket_mode)
|
||||
|
||||
# Setup models for training
|
||||
mp.model.requires_grad_(False)
|
||||
mp.model.requires_grad_(False).train()
|
||||
|
||||
# Load existing LoRA weights if provided
|
||||
existing_weights, existing_steps = _load_existing_lora(existing_lora)
|
||||
|
||||
@ -134,6 +134,17 @@ class CreateVideo(io.ComfyNode):
|
||||
io.Image.Input("images", tooltip="The images to create a video from."),
|
||||
io.Float.Input("fps", default=30.0, min=1.0, max=120.0, step=1.0),
|
||||
io.Audio.Input("audio", optional=True, tooltip="The audio to add to the video."),
|
||||
io.Int.Input(
|
||||
"bit_depth",
|
||||
min=8,
|
||||
max=10,
|
||||
default=8,
|
||||
step=2,
|
||||
tooltip="Bit depth of the created video. 10-bit keeps smoother gradients with less"
|
||||
" banding, but some players and downstream nodes may not support it.",
|
||||
optional=True,
|
||||
display_mode=io.NumberDisplay.number,
|
||||
),
|
||||
],
|
||||
outputs=[
|
||||
io.Video.Output(),
|
||||
@ -141,9 +152,14 @@ class CreateVideo(io.ComfyNode):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None) -> io.NodeOutput:
|
||||
def execute(
|
||||
cls, images: Input.Image, fps: float, audio: Optional[Input.Audio] = None, bit_depth: int = 8,
|
||||
) -> io.NodeOutput:
|
||||
return io.NodeOutput(
|
||||
InputImpl.VideoFromComponents(Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)))
|
||||
InputImpl.VideoFromComponents(
|
||||
Types.VideoComponents(images=images, audio=audio, frame_rate=Fraction(fps)),
|
||||
bit_depth=bit_depth,
|
||||
)
|
||||
)
|
||||
|
||||
class GetVideoComponents(io.ComfyNode):
|
||||
@ -154,7 +170,7 @@ class GetVideoComponents(io.ComfyNode):
|
||||
search_aliases=["extract frames", "split video", "video to images", "demux"],
|
||||
display_name="Get Video Components",
|
||||
category="video",
|
||||
description="Extracts all components from a video: frames, audio, and framerate.",
|
||||
description="Extracts all components from a video: frames, audio, framerate, and bit depth.",
|
||||
inputs=[
|
||||
io.Video.Input("video", tooltip="The video to extract components from."),
|
||||
],
|
||||
@ -162,13 +178,14 @@ class GetVideoComponents(io.ComfyNode):
|
||||
io.Image.Output(display_name="images"),
|
||||
io.Audio.Output(display_name="audio"),
|
||||
io.Float.Output(display_name="fps"),
|
||||
io.Int.Output(display_name="bit_depth"),
|
||||
],
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, video: Input.Video) -> io.NodeOutput:
|
||||
components = video.get_components()
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate))
|
||||
return io.NodeOutput(components.images, components.audio, float(components.frame_rate), video.get_bit_depth())
|
||||
|
||||
|
||||
class LoadVideo(io.ComfyNode):
|
||||
|
||||
@ -1456,63 +1456,6 @@ class WanInfiniteTalkToVideo(io.ComfyNode):
|
||||
return io.NodeOutput(model_patched, positive, negative, out_latent, trim_image)
|
||||
|
||||
|
||||
class WanSCAILToVideo(io.ComfyNode):
|
||||
@classmethod
|
||||
def define_schema(cls):
|
||||
return io.Schema(
|
||||
node_id="WanSCAILToVideo",
|
||||
category="model/conditioning/video_models",
|
||||
inputs=[
|
||||
io.Conditioning.Input("positive"),
|
||||
io.Conditioning.Input("negative"),
|
||||
io.Vae.Input("vae"),
|
||||
io.Int.Input("width", default=512, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("height", default=896, min=32, max=nodes.MAX_RESOLUTION, step=32),
|
||||
io.Int.Input("length", default=81, min=1, max=nodes.MAX_RESOLUTION, step=4),
|
||||
io.Int.Input("batch_size", default=1, min=1, max=4096),
|
||||
io.ClipVisionOutput.Input("clip_vision_output", optional=True),
|
||||
io.Image.Input("reference_image", optional=True),
|
||||
io.Image.Input("pose_video", optional=True, tooltip="Video used for pose conditioning. Will be downscaled to half the resolution of the main video."),
|
||||
io.Float.Input("pose_strength", default=1.0, min=0.0, max=10.0, step=0.01, tooltip="Strength of the pose latent."),
|
||||
io.Float.Input("pose_start", default=0.0, min=0.0, max=1.0, step=0.01, tooltip="Start step to use pose conditioning."),
|
||||
io.Float.Input("pose_end", default=1.0, min=0.0, max=1.0, step=0.01, tooltip="End step to use pose conditioning."),
|
||||
],
|
||||
outputs=[
|
||||
io.Conditioning.Output(display_name="positive"),
|
||||
io.Conditioning.Output(display_name="negative"),
|
||||
io.Latent.Output(display_name="latent", tooltip="Empty latent of the generation size."),
|
||||
],
|
||||
is_experimental=True,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def execute(cls, positive, negative, vae, width, height, length, batch_size, pose_strength, pose_start, pose_end, reference_image=None, clip_vision_output=None, pose_video=None) -> io.NodeOutput:
|
||||
latent = torch.zeros([batch_size, 16, ((length - 1) // 4) + 1, height // 8, width // 8], device=comfy.model_management.intermediate_device())
|
||||
|
||||
ref_latent = None
|
||||
if reference_image is not None:
|
||||
reference_image = comfy.utils.common_upscale(reference_image[:1].movedim(-1, 1), width, height, "bilinear", "center").movedim(1, -1)
|
||||
ref_latent = vae.encode(reference_image[:, :, :, :3])
|
||||
|
||||
if ref_latent is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"reference_latents": [ref_latent]}, append=True)
|
||||
negative = node_helpers.conditioning_set_values(negative, {"reference_latents": [torch.zeros_like(ref_latent)]}, append=True)
|
||||
|
||||
if clip_vision_output is not None:
|
||||
positive = node_helpers.conditioning_set_values(positive, {"clip_vision_output": clip_vision_output})
|
||||
negative = node_helpers.conditioning_set_values(negative, {"clip_vision_output": clip_vision_output})
|
||||
|
||||
if pose_video is not None:
|
||||
pose_video = comfy.utils.common_upscale(pose_video[:length].movedim(-1, 1), width // 2, height // 2, "area", "center").movedim(1, -1)
|
||||
pose_video_latent = vae.encode(pose_video[:, :, :, :3]) * pose_strength
|
||||
positive = node_helpers.conditioning_set_values_with_timestep_range(positive, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
negative = node_helpers.conditioning_set_values_with_timestep_range(negative, {"pose_video_latent": pose_video_latent}, pose_start, pose_end)
|
||||
|
||||
out_latent = {}
|
||||
out_latent["samples"] = latent
|
||||
return io.NodeOutput(positive, negative, out_latent)
|
||||
|
||||
|
||||
class WanExtension(ComfyExtension):
|
||||
@override
|
||||
async def get_node_list(self) -> list[type[io.ComfyNode]]:
|
||||
@ -1533,7 +1476,6 @@ class WanExtension(ComfyExtension):
|
||||
WanAnimateToVideo,
|
||||
Wan22ImageToVideoLatent,
|
||||
WanInfiniteTalkToVideo,
|
||||
WanSCAILToVideo,
|
||||
]
|
||||
|
||||
async def comfy_entrypoint() -> WanExtension:
|
||||
|
||||
@ -2,6 +2,7 @@ import os
|
||||
import importlib.util
|
||||
from comfy.cli_args import args, PerformanceFeature
|
||||
import subprocess
|
||||
import re
|
||||
|
||||
#Can't use pytorch to get the GPU names because the cuda malloc has to be set before the first import.
|
||||
def get_gpu_names():
|
||||
@ -77,11 +78,24 @@ try:
|
||||
except:
|
||||
pass
|
||||
|
||||
def get_raw_cuda_version(version_str):
|
||||
match = re.search(r'\+cu(\d+)', version_str)
|
||||
if match:
|
||||
try:
|
||||
return int(match.group(1))
|
||||
except:
|
||||
pass
|
||||
return None
|
||||
|
||||
if not args.cuda_malloc:
|
||||
try:
|
||||
if int(version[0]) >= 2 and "+cu" in version: # enable by default for torch version 2.0 and up only on cuda torch
|
||||
if PerformanceFeature.AutoTune not in args.fast: # Autotune has issues with cuda malloc
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
cuda_version = get_raw_cuda_version(version)
|
||||
if cuda_version is not None and cuda_version >= 130:
|
||||
args.cuda_malloc = True
|
||||
else:
|
||||
args.cuda_malloc = cuda_malloc_supported()
|
||||
except:
|
||||
pass
|
||||
|
||||
|
||||
10
execution.py
10
execution.py
@ -40,6 +40,7 @@ from comfy_execution.graph_utils import GraphBuilder, is_link
|
||||
from comfy_execution.validation import validate_node_input
|
||||
from comfy_execution.progress import get_progress_state, reset_progress_state, add_progress_handler, WebUIProgressHandler
|
||||
from comfy_execution.utils import CurrentNodeContext
|
||||
from comfy_execution.asset_enrichment import enrich_output_with_assets
|
||||
from comfy_api.internal import _ComfyNodeInternal, _NodeOutputInternal, first_real_override, is_class, make_locked_method_func
|
||||
from comfy_api.latest import io, _io
|
||||
from comfy_execution.cache_provider import _has_cache_providers, _get_cache_providers, _logger as _cache_logger
|
||||
@ -199,6 +200,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
hidden_inputs_v3[io.Hidden.auth_token_comfy_org] = extra_data.get("auth_token_comfy_org", None)
|
||||
if io.Hidden.api_key_comfy_org.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.api_key_comfy_org] = extra_data.get("api_key_comfy_org", None)
|
||||
if io.Hidden.comfy_usage_source.name in hidden:
|
||||
hidden_inputs_v3[io.Hidden.comfy_usage_source] = extra_data.get("comfy_usage_source", None)
|
||||
else:
|
||||
if "hidden" in valid_inputs:
|
||||
h = valid_inputs["hidden"]
|
||||
@ -215,6 +218,8 @@ def get_input_data(inputs, class_def, unique_id, execution_list=None, dynprompt=
|
||||
input_data_all[x] = [extra_data.get("auth_token_comfy_org", None)]
|
||||
if h[x] == "API_KEY_COMFY_ORG":
|
||||
input_data_all[x] = [extra_data.get("api_key_comfy_org", None)]
|
||||
if h[x] == "COMFY_USAGE_SOURCE":
|
||||
input_data_all[x] = [extra_data.get("comfy_usage_source", None)]
|
||||
v3_data["hidden_inputs"] = hidden_inputs_v3
|
||||
return input_data_all, missing_keys, v3_data
|
||||
|
||||
@ -418,6 +423,7 @@ def _is_intermediate_output(dynprompt, node_id):
|
||||
class_def = nodes.NODE_CLASS_MAPPINGS[class_type]
|
||||
return getattr(class_def, 'HAS_INTERMEDIATE_OUTPUT', False)
|
||||
|
||||
|
||||
def _send_cached_ui(server, node_id, display_node_id, cached, prompt_id, ui_outputs):
|
||||
if server.client_id is None:
|
||||
return
|
||||
@ -552,6 +558,10 @@ async def execute(server, dynprompt, caches, current_item, extra_data, executed,
|
||||
asyncio.create_task(await_completion())
|
||||
return (ExecutionResult.PENDING, None, None)
|
||||
if len(output_ui) > 0:
|
||||
# Enrich at output-processing time (not in the send path) so assets
|
||||
# are registered even when no client is connected, and the asset id
|
||||
# flows into ui_outputs and the cache alongside the raw entries.
|
||||
output_ui = enrich_output_with_assets(output_ui)
|
||||
ui_outputs[unique_id] = {
|
||||
"meta": {
|
||||
"node_id": unique_id,
|
||||
|
||||
20
main.py
20
main.py
@ -26,6 +26,7 @@ import utils.extra_config
|
||||
from utils.mime_types import init_mime_types
|
||||
import faulthandler
|
||||
import logging
|
||||
import signal
|
||||
import sys
|
||||
from comfy_execution.progress import get_progress_state
|
||||
from comfy_execution.utils import get_executing_context
|
||||
@ -37,7 +38,19 @@ if __name__ == "__main__":
|
||||
os.environ['HF_HUB_DISABLE_TELEMETRY'] = '1'
|
||||
os.environ['DO_NOT_TRACK'] = '1'
|
||||
|
||||
faulthandler.enable(file=sys.stderr, all_threads=False)
|
||||
faulthandler.enable(file=sys.stderr, all_threads=args.debug_hang)
|
||||
if __name__ == "__main__" and args.debug_hang:
|
||||
dumping_traceback = False
|
||||
|
||||
def dump_traceback_on_sigint(signum, frame):
|
||||
global dumping_traceback
|
||||
if dumping_traceback:
|
||||
raise KeyboardInterrupt
|
||||
dumping_traceback = True
|
||||
faulthandler.dump_traceback(file=sys.stderr, all_threads=True)
|
||||
raise KeyboardInterrupt
|
||||
|
||||
signal.signal(signal.SIGINT, dump_traceback_on_sigint)
|
||||
|
||||
import comfy_aimdo.control
|
||||
|
||||
@ -477,6 +490,11 @@ def start_comfyui(asyncio_loop=None):
|
||||
init_custom_nodes=(not args.disable_all_custom_nodes) or len(args.whitelist_custom_nodes) > 0,
|
||||
init_api_nodes=not args.disable_api_nodes
|
||||
))
|
||||
|
||||
# Re-apply Comfy's cuDNN benchmark policy after custom-node imports. Benchmark
|
||||
# mode can request near-card-sized autotune workspaces, and some custom nodes set it at import time.
|
||||
comfy.model_management.set_cudnn_benchmark()
|
||||
|
||||
hook_breaker_ac10a0.restore_functions()
|
||||
|
||||
cuda_malloc_warning()
|
||||
|
||||
@ -1 +1 @@
|
||||
comfyui_manager==4.2.1
|
||||
comfyui_manager==4.2.2
|
||||
|
||||
5
nodes.py
5
nodes.py
@ -2404,6 +2404,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_video.py",
|
||||
"nodes_lumina2.py",
|
||||
"nodes_wan.py",
|
||||
"nodes_bernini.py",
|
||||
"nodes_lotus.py",
|
||||
"nodes_hunyuan3d.py",
|
||||
"nodes_primitive.py",
|
||||
@ -2450,6 +2451,7 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_rtdetr.py",
|
||||
"nodes_frame_interpolation.py",
|
||||
"nodes_sam3.py",
|
||||
"nodes_scail.py",
|
||||
"nodes_void.py",
|
||||
"nodes_wandancer.py",
|
||||
"nodes_hidream_o1.py",
|
||||
@ -2457,7 +2459,8 @@ async def init_builtin_extra_nodes():
|
||||
"nodes_moge.py",
|
||||
"nodes_mediapipe.py",
|
||||
"nodes_gaussian_splat.py",
|
||||
"nodes_triposplat.py"
|
||||
"nodes_triposplat.py",
|
||||
"nodes_depth_anything_3.py",
|
||||
]
|
||||
|
||||
import_failed = []
|
||||
|
||||
12
openapi.yaml
12
openapi.yaml
@ -896,6 +896,11 @@ components:
|
||||
additionalProperties: true
|
||||
description: The workflow graph to execute
|
||||
type: object
|
||||
prompt_id:
|
||||
description: Optional client-supplied job id. Must be a UUID in canonical lowercase hyphenated form; it is echoed back in the response. Omitted or null means the server generates one.
|
||||
format: uuid
|
||||
nullable: true
|
||||
type: string
|
||||
workflow_id:
|
||||
description: UUID identifying the cloud workflow entity to associate with this job
|
||||
type: string
|
||||
@ -1062,6 +1067,9 @@ components:
|
||||
comfyui_version:
|
||||
description: ComfyUI version
|
||||
type: string
|
||||
deploy_environment:
|
||||
description: How this ComfyUI instance is deployed (e.g. cloud, local-git, local-portable, local-desktop)
|
||||
type: string
|
||||
embedded_python:
|
||||
description: Whether using embedded Python
|
||||
type: boolean
|
||||
@ -1960,8 +1968,8 @@ paths:
|
||||
schema:
|
||||
properties:
|
||||
hash:
|
||||
description: Hash of the existing asset. Supports Blake3 (blake3:) or SHA256 (sha256:) formats
|
||||
pattern: ^(blake3|sha256):[a-f0-9]{64}$
|
||||
description: 'Blake3 content hash of the existing asset (blake3: prefix)'
|
||||
pattern: ^blake3:[a-f0-9]{64}$
|
||||
type: string
|
||||
mime_type:
|
||||
description: MIME type of the asset (e.g., "image/png", "video/mp4")
|
||||
|
||||
34
server.py
34
server.py
@ -8,7 +8,7 @@ import time
|
||||
import nodes
|
||||
import folder_paths
|
||||
import execution
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs
|
||||
from comfy_execution.jobs import JobStatus, get_job, get_all_jobs, validate_job_id
|
||||
import uuid
|
||||
import urllib
|
||||
import json
|
||||
@ -27,6 +27,7 @@ import logging
|
||||
|
||||
import mimetypes
|
||||
from comfy.cli_args import args
|
||||
from comfy.deploy_environment import get_deploy_environment
|
||||
import comfy.utils
|
||||
import comfy.model_management
|
||||
from comfy_api import feature_flags
|
||||
@ -690,6 +691,7 @@ class PromptServer():
|
||||
"python_version": sys.version,
|
||||
"pytorch_version": comfy.model_management.torch_version,
|
||||
"embedded_python": os.path.split(os.path.split(sys.executable)[0])[1] == "python_embeded",
|
||||
"deploy_environment": get_deploy_environment(),
|
||||
"argv": sys.argv
|
||||
},
|
||||
"devices": device_entries
|
||||
@ -942,7 +944,21 @@ class PromptServer():
|
||||
|
||||
if "prompt" in json_data:
|
||||
prompt = json_data["prompt"]
|
||||
prompt_id = str(json_data.get("prompt_id", uuid.uuid4()))
|
||||
client_prompt_id = json_data.get("prompt_id")
|
||||
if client_prompt_id is None:
|
||||
# Absent or explicit null: the server mints the id.
|
||||
prompt_id = str(uuid.uuid4())
|
||||
else:
|
||||
try:
|
||||
prompt_id = validate_job_id(client_prompt_id)
|
||||
except ValueError:
|
||||
error = {
|
||||
"type": "invalid_prompt_id",
|
||||
"message": "prompt_id must be a valid UUID",
|
||||
"details": "prompt_id must be a UUID string in canonical lowercase hyphenated form; omit it to let the server generate one",
|
||||
"extra_info": {}
|
||||
}
|
||||
return web.json_response({"error": error, "node_errors": {}}, status=400)
|
||||
|
||||
partial_execution_targets = None
|
||||
if "partial_execution_targets" in json_data:
|
||||
@ -957,6 +973,11 @@ class PromptServer():
|
||||
|
||||
if "client_id" in json_data:
|
||||
extra_data["client_id"] = json_data["client_id"]
|
||||
|
||||
if "comfy_usage_source" not in extra_data:
|
||||
usage_source = request.headers.get("Comfy-Usage-Source")
|
||||
if usage_source:
|
||||
extra_data["comfy_usage_source"] = usage_source
|
||||
if valid[0]:
|
||||
outputs_to_execute = valid[2]
|
||||
sensitive = {}
|
||||
@ -1253,6 +1274,15 @@ class PromptServer():
|
||||
|
||||
if verbose:
|
||||
logging.info("Starting server\n")
|
||||
if args.debug_hang:
|
||||
logging.info(
|
||||
f"{'-' * 80}\n"
|
||||
"ComfyUI has been started in debug-hang mode. Run your workflow as normal up to\n"
|
||||
"the point of the hang or freeze, then use ctrl-C in the cmd or controlling\n"
|
||||
"terminal to dump the python backtraces for debugging. Please attach the extra\n"
|
||||
"debug info to your bug report.\n"
|
||||
f"{'-' * 80}"
|
||||
)
|
||||
for addr in addresses:
|
||||
address = addr[0]
|
||||
port = addr[1]
|
||||
|
||||
@ -6,6 +6,7 @@ import subprocess
|
||||
import sys
|
||||
import tempfile
|
||||
import time
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import Callable, Iterator, Optional
|
||||
|
||||
@ -188,9 +189,17 @@ def _post_multipart_asset(
|
||||
|
||||
@pytest.fixture
|
||||
def make_asset_bytes() -> Callable[[str, int], bytes]:
|
||||
# Salt content per test so it never collides with assets left over from
|
||||
# earlier tests. Delete is now always a soft delete (content is preserved),
|
||||
# so the suite can no longer rely on hard-deleting content for isolation.
|
||||
# Deterministic within a test: the same (name, size) yields the same bytes.
|
||||
salt = uuid.uuid4().bytes
|
||||
|
||||
def _make(name: str, size: int = 8192) -> bytes:
|
||||
seed = sum(ord(c) for c in name) % 251
|
||||
return bytes((i * 31 + seed) % 256 for i in range(size))
|
||||
body = bytearray((i * 31 + seed) % 256 for i in range(size))
|
||||
body[: len(salt)] = salt[:size]
|
||||
return bytes(body)
|
||||
return _make
|
||||
|
||||
|
||||
@ -212,7 +221,7 @@ def asset_factory(http: requests.Session, api_base: str):
|
||||
|
||||
for aid in created:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@ -227,7 +236,11 @@ def seeded_asset(request: pytest.FixtureRequest, http: requests.Session, api_bas
|
||||
if tags is None:
|
||||
tags = ["models", "checkpoints", "unit-tests", "alpha"]
|
||||
meta = {"purpose": "test", "epoch": 1, "flags": ["x", "y"], "nullable": None}
|
||||
files = {"file": (name, b"A" * 4096, "application/octet-stream")}
|
||||
# Unique content per test so the seed always creates a fresh asset (201).
|
||||
# Delete is now always a soft delete, so content from a prior test survives
|
||||
# and would otherwise dedup this upload into an existing asset (200).
|
||||
content = uuid.uuid4().bytes + b"A" * (4096 - 16)
|
||||
files = {"file": (name, content, "application/octet-stream")}
|
||||
form_data = {
|
||||
"tags": json.dumps(tags),
|
||||
"name": name,
|
||||
@ -260,4 +273,4 @@ def autoclean_unit_test_assets(http: requests.Session, api_base: str):
|
||||
break
|
||||
for aid in ids:
|
||||
with contextlib.suppress(Exception):
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=30)
|
||||
http.delete(f"{api_base}/api/assets/{aid}", timeout=30)
|
||||
|
||||
112
tests-unit/assets_test/queries/test_asset_reference_keyset.py
Normal file
112
tests-unit/assets_test/queries/test_asset_reference_keyset.py
Normal file
@ -0,0 +1,112 @@
|
||||
"""Keyset-pagination tiebreaker tests for list_references_page.
|
||||
|
||||
When multiple rows share the same primary sort value (e.g. four assets
|
||||
created in the same microsecond), the secondary `ORDER BY id` is what keeps
|
||||
keyset pagination from losing or repeating rows. This file exercises that
|
||||
branch directly against an in-memory SQLite session — engineering identical
|
||||
timestamps via HTTP is unreliable enough that we work at the query layer.
|
||||
"""
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from app.assets.database.models import Asset, AssetReference
|
||||
from app.assets.database.queries.asset_reference import list_references_page
|
||||
|
||||
|
||||
def _make_ref(session: Session, created_at: datetime, name: str, owner: str = "") -> AssetReference:
|
||||
asset = Asset(hash=f"blake3:{uuid.uuid4().hex}", size_bytes=1024)
|
||||
session.add(asset)
|
||||
session.flush()
|
||||
ref = AssetReference(
|
||||
id=str(uuid.uuid4()),
|
||||
asset_id=asset.id,
|
||||
owner_id=owner,
|
||||
name=name,
|
||||
file_path=f"/tmp/{name}",
|
||||
created_at=created_at,
|
||||
updated_at=created_at,
|
||||
last_access_time=created_at,
|
||||
is_missing=False,
|
||||
)
|
||||
session.add(ref)
|
||||
return ref
|
||||
|
||||
|
||||
@pytest.mark.parametrize("order", ["desc", "asc"])
|
||||
def test_tiebreaker_walks_duplicate_sort_values(session: Session, order: str):
|
||||
"""Four rows with the SAME created_at must paginate cleanly under cursor
|
||||
mode — no row dropped, no row repeated, despite the primary sort column
|
||||
being non-discriminating.
|
||||
"""
|
||||
shared_ts = datetime(2024, 5, 20, 12, 0, 0) # naive UTC, like the DB stores
|
||||
refs = [_make_ref(session, shared_ts, f"tie_{i}.png") for i in range(4)]
|
||||
session.commit()
|
||||
|
||||
expected_ids = sorted([r.id for r in refs], reverse=(order == "desc"))
|
||||
|
||||
# Walk the cursor by hand: page size 2, take 3 pages (2 + 2 + 0).
|
||||
seen: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(4): # generous loop bound; ought to be 2 iterations
|
||||
page, _tag_map, _total = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order=order,
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
seen.extend(p.id for p in page)
|
||||
# Use the last row's (created_at, id) as the next cursor input.
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen == expected_ids, (
|
||||
f"keyset tiebreaker failed for order={order}: expected {expected_ids}, got {seen}"
|
||||
)
|
||||
|
||||
|
||||
def test_tiebreaker_no_duplicates_under_mixed_collisions(session: Session):
|
||||
"""Some rows share a timestamp, some don't. The cursor must still walk
|
||||
every row exactly once regardless of where ties sit relative to a
|
||||
page boundary."""
|
||||
t1 = datetime(2024, 5, 20, 12, 0, 0)
|
||||
t2 = datetime(2024, 5, 20, 12, 0, 1)
|
||||
layout = [t1, t1, t1, t2, t2] # three rows at t1, two at t2
|
||||
refs = [_make_ref(session, ts, f"mix_{i}.png") for i, ts in enumerate(layout)]
|
||||
session.commit()
|
||||
|
||||
all_ids = {r.id for r in refs}
|
||||
seen_set: set[str] = set()
|
||||
seen_list: list[str] = []
|
||||
after_value = None
|
||||
after_id = None
|
||||
for _ in range(6):
|
||||
page, _, _ = list_references_page(
|
||||
session,
|
||||
limit=2,
|
||||
sort="created_at",
|
||||
order="desc",
|
||||
after_cursor_value=after_value,
|
||||
after_cursor_id=after_id,
|
||||
)
|
||||
if not page:
|
||||
break
|
||||
for p in page:
|
||||
assert p.id not in seen_set, f"duplicate row {p.id} appeared in cursor walk"
|
||||
seen_set.add(p.id)
|
||||
seen_list.append(p.id)
|
||||
last = page[-1]
|
||||
after_value, after_id = last.created_at, last.id
|
||||
if len(page) < 2:
|
||||
break
|
||||
|
||||
assert seen_set == all_ids, f"missing rows: expected {all_ids}, got {seen_set}"
|
||||
@ -40,15 +40,15 @@ def _make_reference(session: Session, asset: Asset, name: str = "test", owner_id
|
||||
|
||||
class TestEnsureTagsExist:
|
||||
def test_creates_new_tags(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha", "beta"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha", "beta"])
|
||||
session.commit()
|
||||
|
||||
tags = session.query(Tag).all()
|
||||
assert {t.name for t in tags} == {"alpha", "beta"}
|
||||
|
||||
def test_is_idempotent(self, session: Session):
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha"], tag_type="user")
|
||||
ensure_tags_exist(session, ["alpha"])
|
||||
ensure_tags_exist(session, ["alpha"])
|
||||
session.commit()
|
||||
|
||||
assert session.query(Tag).count() == 1
|
||||
@ -65,13 +65,6 @@ class TestEnsureTagsExist:
|
||||
session.commit()
|
||||
assert session.query(Tag).count() == 0
|
||||
|
||||
def test_tag_type_is_set(self, session: Session):
|
||||
ensure_tags_exist(session, ["system-tag"], tag_type="system")
|
||||
session.commit()
|
||||
|
||||
tag = session.query(Tag).filter_by(name="system-tag").one()
|
||||
assert tag.tag_type == "system"
|
||||
|
||||
|
||||
class TestGetReferenceTags:
|
||||
def test_returns_empty_for_no_tags(self, session: Session):
|
||||
@ -193,7 +186,7 @@ class TestMissingTagFunctions:
|
||||
def test_add_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
ensure_tags_exist(session, ["missing"])
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
session.commit()
|
||||
@ -204,7 +197,7 @@ class TestMissingTagFunctions:
|
||||
def test_add_missing_tag_is_idempotent(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
ensure_tags_exist(session, ["missing"])
|
||||
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
@ -216,7 +209,7 @@ class TestMissingTagFunctions:
|
||||
def test_remove_missing_tag_for_asset_id(self, session: Session):
|
||||
asset = _make_asset(session, "hash1")
|
||||
ref = _make_reference(session, asset)
|
||||
ensure_tags_exist(session, ["missing"], tag_type="system")
|
||||
ensure_tags_exist(session, ["missing"])
|
||||
add_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
|
||||
remove_missing_tag_for_asset_id(session, asset_id=asset.id)
|
||||
@ -237,7 +230,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, total = list_tags_with_usage(session)
|
||||
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict["used"] == 1
|
||||
assert tag_dict["unused"] == 0
|
||||
assert total == 2
|
||||
@ -252,7 +245,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, total = list_tags_with_usage(session, include_zero=False)
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert "used" in tag_names
|
||||
assert "unused" not in tag_names
|
||||
|
||||
@ -262,7 +255,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, total = list_tags_with_usage(session, prefix="alph")
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert tag_names == {"alpha", "alphabet"}
|
||||
|
||||
def test_order_by_name(self, session: Session):
|
||||
@ -271,7 +264,7 @@ class TestListTagsWithUsage:
|
||||
|
||||
rows, _ = list_tags_with_usage(session, order="name_asc")
|
||||
|
||||
names = [name for name, _, _ in rows]
|
||||
names = [name for name, _ in rows]
|
||||
assert names == ["alpha", "middle", "zebra"]
|
||||
|
||||
def test_owner_visibility(self, session: Session):
|
||||
@ -287,13 +280,13 @@ class TestListTagsWithUsage:
|
||||
|
||||
# Empty owner sees only shared
|
||||
rows, _ = list_tags_with_usage(session, owner_id="", include_zero=False)
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict.get("shared-tag", 0) == 1
|
||||
assert tag_dict.get("owner-tag", 0) == 0
|
||||
|
||||
# User1 sees both
|
||||
rows, _ = list_tags_with_usage(session, owner_id="user1", include_zero=False)
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict.get("shared-tag", 0) == 1
|
||||
assert tag_dict.get("owner-tag", 0) == 1
|
||||
|
||||
|
||||
278
tests-unit/assets_test/services/test_cursor.py
Normal file
278
tests-unit/assets_test/services/test_cursor.py
Normal file
@ -0,0 +1,278 @@
|
||||
"""Tests for app.assets.services.cursor.
|
||||
|
||||
Cursors are opaque tokens internal to this server — these tests cover
|
||||
round-tripping, validation, and length caps, not any particular wire
|
||||
byte layout.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
from datetime import datetime, timedelta, timezone
|
||||
|
||||
import pytest
|
||||
|
||||
from app.assets.services.cursor import (
|
||||
MAX_CURSOR_ID_LENGTH,
|
||||
MAX_CURSOR_VALUE_LENGTH,
|
||||
MAX_ENCODED_CURSOR_LENGTH,
|
||||
CursorPayload,
|
||||
InvalidCursorError,
|
||||
decode_cursor,
|
||||
decode_cursor_int,
|
||||
decode_cursor_time,
|
||||
encode_cursor,
|
||||
encode_cursor_from_time,
|
||||
)
|
||||
|
||||
|
||||
ALLOWED = ("created_at", "updated_at", "name", "size")
|
||||
|
||||
|
||||
class TestRoundTrip:
|
||||
@pytest.mark.parametrize(
|
||||
"sort_field, value, id",
|
||||
[
|
||||
("created_at", "1716200000000000", "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"),
|
||||
("size", "1024", "asset-123"),
|
||||
("name", "my-asset.png", "asset-abc"),
|
||||
("name", "résumé.txt", "asset-uni"),
|
||||
("name", "foo<&>bar.png", "asset-html"),
|
||||
("name", 'quo"te\\back\nnewline.png', "asset-esc"),
|
||||
],
|
||||
)
|
||||
def test_encode_decode(self, sort_field, value, id):
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert encoded != ""
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.sort_field == sort_field
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestTimeCursor:
|
||||
def test_microsecond_precision_preserved(self):
|
||||
# Pick a time with non-zero microseconds — encoding at ms would lose the µs.
|
||||
ts = datetime(2024, 5, 20, 12, 53, 20, 123456, tzinfo=timezone.utc)
|
||||
encoded = encode_cursor_from_time("created_at", ts, "id-1")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
# Value must be a microsecond integer string, not a millisecond one.
|
||||
assert payload.value == "1716209600123456"
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded == ts
|
||||
|
||||
def test_decode_returns_utc(self):
|
||||
payload = CursorPayload(sort_field="created_at", value="1716200000123456", id="id-1", order="desc")
|
||||
decoded = decode_cursor_time(payload)
|
||||
assert decoded.tzinfo == timezone.utc
|
||||
|
||||
def test_naive_datetime_rejected_on_encode(self):
|
||||
naive = datetime(2024, 5, 20, 12, 0, 0)
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor_from_time("created_at", naive, "id-1")
|
||||
|
||||
def test_non_integer_value_rejected_on_decode(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(CursorPayload("created_at", "not-a-number", "id-1", "desc"))
|
||||
|
||||
def test_none_payload_rejected(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(None)
|
||||
|
||||
def test_non_utc_aware_normalized(self):
|
||||
# Same instant, different timezone — must encode to the same micros.
|
||||
utc_ts = datetime(2024, 5, 20, 12, 0, 0, tzinfo=timezone.utc)
|
||||
offset_ts = utc_ts.astimezone(timezone(timedelta(hours=-5)))
|
||||
assert encode_cursor_from_time("created_at", utc_ts, "x") == encode_cursor_from_time(
|
||||
"created_at", offset_ts, "x"
|
||||
)
|
||||
|
||||
|
||||
class TestIntCursor:
|
||||
def test_decode_int(self):
|
||||
assert decode_cursor_int(CursorPayload("size", "1024", "id-1", "desc")) == 1024
|
||||
|
||||
def test_decode_int_rejects_non_int(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(CursorPayload("size", "abc", "id-1", "desc"))
|
||||
|
||||
def test_decode_int_rejects_none(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_int(None)
|
||||
|
||||
|
||||
class TestInvalidInputs:
|
||||
def test_oversized_cursor(self):
|
||||
oversized = "a" * (MAX_ENCODED_CURSOR_LENGTH + 1)
|
||||
with pytest.raises(InvalidCursorError, match="maximum length"):
|
||||
decode_cursor(oversized, ALLOWED)
|
||||
|
||||
def test_not_base64(self):
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor("not base64!!!", ALLOWED)
|
||||
|
||||
def test_not_json(self):
|
||||
encoded = base64.urlsafe_b64encode(b"definitely not json").rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_empty_id(self):
|
||||
# Encoder rejects empty id symmetrically with the decoder, so build the
|
||||
# payload manually to exercise the decoder's missing-id branch.
|
||||
raw = b'{"s":"created_at","v":"1","id":"","o":"desc"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing id"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_id(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_id = "a" * (MAX_CURSOR_ID_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"1","id":"' + big_id + '","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_oversized_value(self):
|
||||
# Encoder enforces the cap symmetrically; hand-build to exercise decode.
|
||||
big_v = "v" * (MAX_CURSOR_VALUE_LENGTH + 1)
|
||||
raw = ('{"s":"created_at","v":"' + big_v + '","id":"id-1","o":"desc"}').encode("ascii")
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_unsupported_sort_field(self):
|
||||
encoded = encode_cursor("execution_time", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported sort field"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_no_allowed_fields_rejects_everything(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1")
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor(encoded, ())
|
||||
|
||||
def test_non_dict_payload_rejected(self):
|
||||
encoded = base64.urlsafe_b64encode(b'["array","not","dict"]').rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="expected object"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
|
||||
class TestEncodeAtCapsFits:
|
||||
def test_max_field_lengths_fit_wire_cap(self):
|
||||
# Worst-case payload: value and id at their per-field caps, with a long
|
||||
# sort field name. The encoded cursor must fit within MAX_ENCODED_CURSOR_LENGTH
|
||||
# so the wire cap cannot reject a cursor the encoder mints at the per-field caps.
|
||||
value = "v" * MAX_CURSOR_VALUE_LENGTH
|
||||
id = "i" * MAX_CURSOR_ID_LENGTH
|
||||
sort_field = "very_long_sort_field_name"
|
||||
|
||||
encoded = encode_cursor(sort_field, value, id)
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, (sort_field,))
|
||||
assert payload.value == value
|
||||
assert payload.id == id
|
||||
|
||||
|
||||
class TestDatetimeOverflow:
|
||||
"""Crafted cursors with extreme micros must map to InvalidCursorError,
|
||||
not OverflowError/OSError leaking as 500.
|
||||
"""
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"micros_str",
|
||||
[
|
||||
"999999999999999999999", # 10^21 µs — past datetime.MAX_YEAR by ~14 orders
|
||||
"-999999999999999999999", # symmetric negative — pre-epoch overflow
|
||||
],
|
||||
)
|
||||
def test_out_of_range_micros_rejected(self, micros_str):
|
||||
encoded = encode_cursor("created_at", micros_str, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
with pytest.raises(InvalidCursorError):
|
||||
decode_cursor_time(payload)
|
||||
|
||||
|
||||
class TestEncoderDecoderSymmetry:
|
||||
"""The encoder must never mint a cursor the decoder would reject, or the
|
||||
same server would 400 on a cursor it just handed out. Per-field caps keep
|
||||
the encoded length below the wire cap, so a freshly minted cursor always
|
||||
round-trips.
|
||||
"""
|
||||
|
||||
def test_long_name_within_cap_round_trips(self):
|
||||
"""Assets allow names up to 512 chars (`String(512)`); the cursor
|
||||
encoder must round-trip a value at that cap so a freshly minted
|
||||
cursor never fails decode on the next request."""
|
||||
long_name = "n" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", long_name, "asset-x")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == long_name
|
||||
|
||||
def test_encoder_rejects_empty_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id must be non-empty"):
|
||||
encode_cursor("created_at", "1", "")
|
||||
|
||||
def test_encoder_rejects_oversized_id(self):
|
||||
with pytest.raises(InvalidCursorError, match="id exceeds maximum length"):
|
||||
encode_cursor("created_at", "1", "a" * (MAX_CURSOR_ID_LENGTH + 1))
|
||||
|
||||
def test_encoder_rejects_oversized_value(self):
|
||||
with pytest.raises(InvalidCursorError, match="value exceeds maximum length"):
|
||||
encode_cursor("name", "v" * (MAX_CURSOR_VALUE_LENGTH + 1), "id-1")
|
||||
|
||||
def test_multibyte_value_at_cap_round_trips(self):
|
||||
"""A value at the char-count cap made of multibyte characters
|
||||
(e.g. 'é' = 2 UTF-8 bytes) stays under the wire cap, so it mints and
|
||||
round-trips — the per-field caps, not a mint-time length check, are
|
||||
what bound cursor size."""
|
||||
value = "é" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", value, "asset-multibyte")
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == value
|
||||
|
||||
def test_escape_heavy_value_at_cap_round_trips(self):
|
||||
"""JSON escape expansion is the worst case: each control character
|
||||
serializes to the six-byte `\\uXXXX` form. A value of 512 of them is
|
||||
the largest a cursor can get, and it still fits the wire cap, mints,
|
||||
and round-trips."""
|
||||
value = "\x01" * MAX_CURSOR_VALUE_LENGTH
|
||||
encoded = encode_cursor("name", value, "asset-escape")
|
||||
assert len(encoded) <= MAX_ENCODED_CURSOR_LENGTH
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.value == value
|
||||
|
||||
|
||||
class TestOrderBinding:
|
||||
def test_order_baked_into_payload(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="asc")
|
||||
payload = decode_cursor(encoded, ALLOWED)
|
||||
assert payload.order == "asc"
|
||||
|
||||
def test_mismatched_order_rejected(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
with pytest.raises(InvalidCursorError, match="does not match request order"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="asc")
|
||||
|
||||
def test_matching_order_accepted(self):
|
||||
encoded = encode_cursor("created_at", "1", "id-1", order="desc")
|
||||
payload = decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
assert payload.order == "desc"
|
||||
|
||||
def test_invalid_order_token_rejected_on_encode(self):
|
||||
with pytest.raises(ValueError):
|
||||
encode_cursor("created_at", "1", "id-1", order="sideways")
|
||||
|
||||
def test_invalid_order_token_rejected_on_decode(self):
|
||||
# Hand-craft a payload with an illegal `o` value.
|
||||
raw = b'{"s":"name","v":"x","id":"id-1","o":"sideways"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="unsupported order"):
|
||||
decode_cursor(encoded, ALLOWED)
|
||||
|
||||
def test_cursor_without_order_rejected(self):
|
||||
"""`o` is mandatory. A cursor minted without it is rejected as
|
||||
malformed rather than silently walking the keyset in whatever
|
||||
direction the request happens to ask for."""
|
||||
raw = b'{"s":"name","v":"x","id":"id-1"}'
|
||||
encoded = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
with pytest.raises(InvalidCursorError, match="missing or non-string o"):
|
||||
decode_cursor(encoded, ALLOWED, expected_order="desc")
|
||||
@ -141,7 +141,7 @@ class TestListTags:
|
||||
|
||||
rows, total = list_tags()
|
||||
|
||||
tag_dict = {name: count for name, _, count in rows}
|
||||
tag_dict = {name: count for name, count in rows}
|
||||
assert tag_dict["used"] == 1
|
||||
assert tag_dict["unused"] == 0
|
||||
assert total == 2
|
||||
@ -155,7 +155,7 @@ class TestListTags:
|
||||
|
||||
rows, total = list_tags(include_zero=False)
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert "used" in tag_names
|
||||
assert "unused" not in tag_names
|
||||
|
||||
@ -165,7 +165,7 @@ class TestListTags:
|
||||
|
||||
rows, _ = list_tags(prefix="alph")
|
||||
|
||||
tag_names = {name for name, _, _ in rows}
|
||||
tag_names = {name for name, _ in rows}
|
||||
assert tag_names == {"alpha", "alphabet"}
|
||||
|
||||
def test_order_by_name(self, mock_create_session, session: Session):
|
||||
@ -174,7 +174,7 @@ class TestListTags:
|
||||
|
||||
rows, _ = list_tags(order="name_asc")
|
||||
|
||||
names = [name for name, _, _ in rows]
|
||||
names = [name for name, _ in rows]
|
||||
assert names == ["alpha", "middle", "zebra"]
|
||||
|
||||
def test_pagination(self, mock_create_session, session: Session):
|
||||
@ -185,7 +185,7 @@ class TestListTags:
|
||||
|
||||
assert total == 5
|
||||
assert len(rows) == 2
|
||||
names = [name for name, _, _ in rows]
|
||||
names = [name for name, _ in rows]
|
||||
assert names == ["b", "c"]
|
||||
|
||||
def test_clamps_limit(self, mock_create_session, session: Session):
|
||||
|
||||
@ -45,8 +45,8 @@ def test_get_and_delete_asset(http: requests.Session, api_base: str, seeded_asse
|
||||
assert "user_metadata" in detail
|
||||
assert "filename" in detail["user_metadata"]
|
||||
|
||||
# DELETE (hard delete to also remove underlying asset and file)
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120)
|
||||
# Soft delete — the reference is hidden, content is preserved
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
# GET again -> 404
|
||||
@ -60,7 +60,7 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede
|
||||
aid = seeded_asset["id"]
|
||||
asset_hash = seeded_asset["asset_hash"]
|
||||
|
||||
# Soft-delete (default, no delete_content param)
|
||||
# Soft delete — the reference is hidden, content is preserved
|
||||
rd = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
@ -81,11 +81,10 @@ def test_soft_delete_hides_from_get(http: requests.Session, api_base: str, seede
|
||||
ids = [a["id"] for a in rl.json().get("assets", [])]
|
||||
assert aid not in ids
|
||||
|
||||
# Clean up: hard-delete the soft-deleted reference and orphaned asset
|
||||
http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120)
|
||||
# The reference is already soft-deleted; content is preserved.
|
||||
|
||||
|
||||
def test_delete_upon_reference_count(
|
||||
def test_soft_delete_preserves_asset_identity_across_references(
|
||||
http: requests.Session, api_base: str, seeded_asset: dict
|
||||
):
|
||||
# Create a second reference to the same asset via from-hash
|
||||
@ -119,16 +118,20 @@ def test_delete_upon_reference_count(
|
||||
rh2 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||
assert rh2.status_code == 200 # asset identity preserved (soft delete)
|
||||
|
||||
# Re-associate via from-hash, then hard-delete -> orphan content removed
|
||||
# Re-associate via from-hash: it must reuse the same preserved content
|
||||
# (created_new False AND the same hash), proving the soft deletes did not
|
||||
# destroy the underlying asset. Then soft-delete again -> still preserved.
|
||||
r3 = http.post(f"{api_base}/api/assets/from-hash", json=payload, timeout=120)
|
||||
assert r3.status_code == 201, r3.json()
|
||||
assert r3.json()["created_new"] is False
|
||||
assert r3.json()["asset_hash"] == src_hash # reused the surviving content
|
||||
aid3 = r3.json()["id"]
|
||||
|
||||
rd3 = http.delete(f"{api_base}/api/assets/{aid3}?delete_content=true", timeout=120)
|
||||
rd3 = http.delete(f"{api_base}/api/assets/{aid3}", timeout=120)
|
||||
assert rd3.status_code == 204
|
||||
|
||||
rh3 = http.head(f"{api_base}/api/assets/hash/{src_hash}", timeout=120)
|
||||
assert rh3.status_code == 404 # orphan content removed
|
||||
assert rh3.status_code == 200 # content preserved (soft delete)
|
||||
|
||||
|
||||
def test_update_asset_fields(http: requests.Session, api_base: str, seeded_asset: dict):
|
||||
@ -249,7 +252,7 @@ def test_concurrent_delete_same_asset_info_single_204(
|
||||
|
||||
# Hit the same endpoint N times in parallel.
|
||||
n_tests = 4
|
||||
url = f"{api_base}/api/assets/{aid}?delete_content=false"
|
||||
url = f"{api_base}/api/assets/{aid}"
|
||||
|
||||
def _do_delete(delete_url):
|
||||
with requests.Session() as s:
|
||||
|
||||
@ -117,7 +117,7 @@ def test_download_missing_file_returns_404(
|
||||
assert body["error"]["code"] == "FILE_NOT_FOUND"
|
||||
finally:
|
||||
# We created asset without the "unit-tests" tag(see `autoclean_unit_test_assets`), we need to clear it manually.
|
||||
dr = http.delete(f"{api_base}/api/assets/{aid}?delete_content=true", timeout=120)
|
||||
dr = http.delete(f"{api_base}/api/assets/{aid}", timeout=120)
|
||||
dr.content
|
||||
|
||||
|
||||
|
||||
349
tests-unit/assets_test/test_list_cursor.py
Normal file
349
tests-unit/assets_test/test_list_cursor.py
Normal file
@ -0,0 +1,349 @@
|
||||
"""Integration tests for cursor-based pagination on GET /api/assets.
|
||||
|
||||
These tests exercise the handler/service/query path end-to-end;
|
||||
cursor-encoding-level tests live in
|
||||
tests-unit/assets_test/services/test_cursor.py.
|
||||
"""
|
||||
import pytest
|
||||
import requests
|
||||
|
||||
|
||||
def _seed(asset_factory, make_asset_bytes, count: int, tag: str) -> list[str]:
|
||||
names = [f"cursor_{i:02d}.safetensors" for i in range(count)]
|
||||
for n in names:
|
||||
asset_factory(
|
||||
n,
|
||||
["models", "checkpoints", "unit-tests", tag],
|
||||
{},
|
||||
make_asset_bytes(n, size=2048),
|
||||
)
|
||||
return sorted(names)
|
||||
|
||||
|
||||
def test_cursor_pages_all_items_in_order(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=5, tag="cursor-walk")
|
||||
|
||||
params = {
|
||||
"include_tags": "unit-tests,cursor-walk",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
}
|
||||
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
pages += 1
|
||||
after = body.get("next_cursor")
|
||||
if after is None:
|
||||
break
|
||||
assert body["has_more"] is True
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
assert seen == names, f"expected {names}, got {seen}"
|
||||
# Last page should have has_more False
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_invalid_returns_400(http: requests.Session, api_base: str):
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": "not-a-real-cursor", "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
body = r.json()
|
||||
assert body["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_sort_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-mismatch")
|
||||
|
||||
# Take a real cursor minted for sort=name.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-mismatch",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay against sort=created_at — should fail with INVALID_CURSOR.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_wins_over_offset(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-vs-offset")
|
||||
|
||||
# Take a cursor that points past the first item.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Pass both 'after' and a large offset. Cursor must win; offset is ignored.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-vs-offset",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
"offset": "999",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200
|
||||
body = r2.json()
|
||||
# Should land on the second name in sorted order — not skip ahead by 999.
|
||||
assert [a["name"] for a in body["assets"]] == [names[1]]
|
||||
|
||||
|
||||
def test_next_cursor_absent_when_no_more_results(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
_seed(asset_factory, make_asset_bytes, count=2, tag="cursor-exhaust")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-exhaust",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "50",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
def test_cursor_pagination_first_page_mints_cursor(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""First-page request (no `after`) must still return `next_cursor` when
|
||||
more rows exist, or pagination is unreachable from a cold start.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-first-page")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-first-page", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
assert body["has_more"] is True
|
||||
assert body.get("next_cursor"), "first page must mint a cursor when more rows exist"
|
||||
|
||||
|
||||
def test_cursor_no_spurious_cursor_when_page_size_equals_remainder(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""When `total` is an exact multiple of `limit`, the final page must
|
||||
NOT carry a next_cursor — there is nothing past it.
|
||||
"""
|
||||
_seed(asset_factory, make_asset_bytes, count=4, tag="cursor-exact-multiple")
|
||||
# Page 1
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
# Page 2 — should exhaust the set with no cursor for a phantom page 3
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"include_tags": "unit-tests,cursor-exact-multiple", "sort": "name", "order": "asc", "limit": "2", "after": cursor},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body = r2.json()
|
||||
assert len(body["assets"]) == 2
|
||||
assert body["has_more"] is False
|
||||
assert "next_cursor" not in body
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sort_field", ["created_at", "updated_at", "size"])
|
||||
def test_cursor_walks_for_non_name_sorts(sort_field, http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""Cursor pagination must work for every sort field the contract claims.
|
||||
|
||||
Without this, the `created_at` / `updated_at` (time-encoded micros) and
|
||||
`size` (int-encoded) cursor paths go entirely unexercised end-to-end.
|
||||
"""
|
||||
# Sizes increase strictly by index, so `size desc` has a deterministic
|
||||
# expected order. Time-based sorts (created_at / updated_at) can tie when
|
||||
# rows are inserted faster than the DB's timestamp resolution; for those
|
||||
# we check coverage and no-duplicates and let the keyset tiebreaker do
|
||||
# the rest, instead of sleeping between inserts and asserting an order
|
||||
# that depends on clock granularity.
|
||||
names = []
|
||||
for i in range(4):
|
||||
n = f"cursor_{sort_field}_{i:02d}.safetensors"
|
||||
asset_factory(n, ["models", "checkpoints", "unit-tests", f"cursor-{sort_field}"], {}, make_asset_bytes(n, size=2048 + i))
|
||||
names.append(n)
|
||||
|
||||
params = {
|
||||
"include_tags": f"unit-tests,cursor-{sort_field}",
|
||||
"sort": sort_field,
|
||||
"order": "desc",
|
||||
"limit": "2",
|
||||
}
|
||||
seen: list[str] = []
|
||||
after: str | None = None
|
||||
pages = 0
|
||||
while True:
|
||||
page_params = dict(params)
|
||||
if after is not None:
|
||||
page_params["after"] = after
|
||||
r = http.get(api_base + "/api/assets", params=page_params, timeout=120)
|
||||
assert r.status_code == 200, r.text
|
||||
body = r.json()
|
||||
seen.extend(a["name"] for a in body["assets"])
|
||||
after = body.get("next_cursor")
|
||||
pages += 1
|
||||
if after is None:
|
||||
break
|
||||
assert pages < 10, "guard against runaway cursor loop"
|
||||
|
||||
# No duplicates: a faulty keyset boundary that returns the same row across
|
||||
# two pages must fail this check.
|
||||
assert len(seen) == len(set(seen)), (
|
||||
f"cursor walk repeated rows for sort={sort_field}: {seen}"
|
||||
)
|
||||
# Full coverage: every seeded asset reached exactly once.
|
||||
assert set(seen) == set(names), (
|
||||
f"missing items for sort={sort_field}: expected {set(names)}, got {set(seen)}"
|
||||
)
|
||||
# Strict order check for the only field with a clock-independent ordering.
|
||||
if sort_field == "size":
|
||||
assert seen == list(reversed(names)), (
|
||||
f"size cursor walked out of order: got {seen}, expected {list(reversed(names))}"
|
||||
)
|
||||
|
||||
|
||||
def test_cursor_order_mismatch_returns_400(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
"""A cursor minted under desc order replayed against asc must 400, not
|
||||
silently walk the wrong direction."""
|
||||
_seed(asset_factory, make_asset_bytes, count=3, tag="cursor-order-flip")
|
||||
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "desc",
|
||||
"limit": "1",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200, r.text
|
||||
cursor = r.json()["next_cursor"]
|
||||
assert cursor is not None
|
||||
|
||||
# Replay with order flipped to asc — server must reject the cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-order-flip",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "1",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 400, r2.text
|
||||
assert r2.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_invalid_cursor_at_microsecond_boundary(http: requests.Session, api_base: str):
|
||||
"""A cursor carrying an out-of-range microsecond timestamp must map to
|
||||
400 INVALID_CURSOR, not 500."""
|
||||
import base64
|
||||
import json
|
||||
# 10^18 microseconds ≈ year 33658, well past datetime.MAX_YEAR.
|
||||
# `o` and `order=` must be set; otherwise decode fails earlier on the
|
||||
# missing-order branch and the µs-overflow path is never exercised.
|
||||
payload = {"s": "created_at", "o": "desc", "v": "999999999999999999999", "id": "asset-x"}
|
||||
raw = json.dumps(payload, separators=(",", ":")).encode("utf-8")
|
||||
cursor = base64.urlsafe_b64encode(raw).rstrip(b"=").decode("ascii")
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={"after": cursor, "sort": "created_at", "order": "desc"},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert r.json()["error"]["code"] == "INVALID_CURSOR"
|
||||
|
||||
|
||||
def test_cursor_pagination_stable_after_delete(http: requests.Session, api_base: str, asset_factory, make_asset_bytes):
|
||||
names = _seed(asset_factory, make_asset_bytes, count=4, tag="cursor-delete")
|
||||
|
||||
# Page 1.
|
||||
r = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r.status_code == 200
|
||||
body = r.json()
|
||||
page1_names = [a["name"] for a in body["assets"]]
|
||||
cursor = body["next_cursor"]
|
||||
assert cursor is not None
|
||||
assert page1_names == names[:2]
|
||||
|
||||
# Delete an item from page 1 (already returned) — cursor should still
|
||||
# locate the next page from where it was minted, not re-index.
|
||||
target_id = body["assets"][0]["id"]
|
||||
d = http.delete(api_base + f"/api/assets/{target_id}", timeout=120)
|
||||
assert d.status_code in (200, 204), d.text
|
||||
|
||||
# Page 2 via cursor.
|
||||
r2 = http.get(
|
||||
api_base + "/api/assets",
|
||||
params={
|
||||
"include_tags": "unit-tests,cursor-delete",
|
||||
"sort": "name",
|
||||
"order": "asc",
|
||||
"limit": "2",
|
||||
"after": cursor,
|
||||
},
|
||||
timeout=120,
|
||||
)
|
||||
assert r2.status_code == 200, r2.text
|
||||
body2 = r2.json()
|
||||
assert [a["name"] for a in body2["assets"]] == names[2:]
|
||||
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
69
tests-unit/assets_test/test_prompt_id_enforcement.py
Normal file
@ -0,0 +1,69 @@
|
||||
"""POST /prompt enforces canonical-UUID job ids at creation time.
|
||||
|
||||
Lives in assets_test because it uses this suite's booted-server fixture. The
|
||||
invariant itself is pipeline-wide: a job id is stored and compared verbatim
|
||||
downstream — history keys, websocket correlation, and /interrupt matching —
|
||||
so a job minted with a non-canonical id would miss every exact-match lookup.
|
||||
|
||||
The prompt bodies here are intentionally invalid workflows — prompt_id
|
||||
validation happens before workflow validation, so a rejected id returns
|
||||
``invalid_prompt_id`` while an accepted id falls through to the ordinary
|
||||
workflow-validation error (proving it cleared the id check).
|
||||
"""
|
||||
import requests
|
||||
|
||||
|
||||
def _post_prompt(http: requests.Session, api_base: str, body: dict) -> requests.Response:
|
||||
return http.post(api_base + "/prompt", json=body, timeout=30)
|
||||
|
||||
|
||||
def _error_type(r: requests.Response) -> str:
|
||||
return r.json()["error"]["type"]
|
||||
|
||||
|
||||
def test_non_uuid_prompt_id_rejected(http: requests.Session, api_base: str):
|
||||
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": "not-a-uuid"})
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_non_string_prompt_id_rejected(http: requests.Session, api_base: str):
|
||||
# Previously str()-coerced (123 became the job id "123"); must now be a 400,
|
||||
# not a 500 from uuid.UUID choking on a non-string.
|
||||
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": 123})
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_non_canonical_uuid_rejected(http: requests.Session, api_base: str):
|
||||
# Parseable as a UUID, but not the canonical lowercase form: rejected
|
||||
# loudly rather than silently rewritten (downstream lookups match the
|
||||
# stored id exactly).
|
||||
r = _post_prompt(
|
||||
http,
|
||||
api_base,
|
||||
{"prompt": {}, "prompt_id": "AAAAAAAA-BBBB-4CCC-8DDD-EEEEEEEEEEEE"},
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) == "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_canonical_uuid_accepted(http: requests.Session, api_base: str):
|
||||
# The id clears validation; the empty workflow then fails ordinary prompt
|
||||
# validation, proving the request got past the id check.
|
||||
r = _post_prompt(
|
||||
http,
|
||||
api_base,
|
||||
{"prompt": {}, "prompt_id": "aaaaaaaa-bbbb-4ccc-8ddd-eeeeeeeeeeee"},
|
||||
)
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) != "invalid_prompt_id"
|
||||
|
||||
|
||||
def test_null_prompt_id_not_rejected(http: requests.Session, api_base: str):
|
||||
# Explicit null means "server generates" and must not be rejected as an
|
||||
# invalid id. (The minted id itself is not observable here because the
|
||||
# workflow is invalid; unit tests cover validate_job_id directly.)
|
||||
r = _post_prompt(http, api_base, {"prompt": {}, "prompt_id": None})
|
||||
assert r.status_code == 400, r.text
|
||||
assert _error_type(r) != "invalid_prompt_id"
|
||||
@ -95,7 +95,7 @@ def _make_asset(
|
||||
def _ensure_missing_tag(session: Session):
|
||||
"""Ensure the 'missing' tag exists."""
|
||||
if not session.get(Tag, "missing"):
|
||||
session.add(Tag(name="missing", tag_type="system"))
|
||||
session.add(Tag(name="missing"))
|
||||
session.flush()
|
||||
|
||||
|
||||
|
||||
@ -69,8 +69,8 @@ def test_tags_empty_usage(http: requests.Session, api_base: str, asset_factory,
|
||||
used_names = [t["name"] for t in body2["tags"]]
|
||||
assert custom_tag in used_names
|
||||
|
||||
# Hard-delete the asset so the tag usage drops to zero
|
||||
rd = http.delete(f"{api_base}/api/assets/{_asset['id']}?delete_content=true", timeout=120)
|
||||
# Delete the asset reference so the tag usage drops to zero
|
||||
rd = http.delete(f"{api_base}/api/assets/{_asset['id']}", timeout=120)
|
||||
assert rd.status_code == 204
|
||||
|
||||
# Now the custom tag must not be returned when include_zero=false
|
||||
|
||||
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
93
tests-unit/comfy_api_test/video_bit_depth_test.py
Normal file
@ -0,0 +1,93 @@
|
||||
import pytest
|
||||
import torch
|
||||
import av
|
||||
import numpy as np
|
||||
from fractions import Fraction
|
||||
from comfy_api.latest._input_impl.video_types import VideoFromFile, VideoFromComponents
|
||||
from comfy_api.latest._util.video_types import VideoComponents
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def gradient_components():
|
||||
"""Narrow horizontal ramp (0.25..0.30) that needs more than 8 bits to stay smooth"""
|
||||
width, height, frames = 64, 64, 3
|
||||
ramp = torch.linspace(0.25, 0.30, width).view(1, 1, width, 1).expand(frames, height, width, 3)
|
||||
return VideoComponents(images=ramp.contiguous(), frame_rate=Fraction(30))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def src8(gradient_components, tmp_path_factory):
|
||||
"""8-bit h264 mp4 (Create Video default)"""
|
||||
path = str(tmp_path_factory.mktemp("video") / "src8.mp4")
|
||||
VideoFromComponents(gradient_components).save_to(path)
|
||||
return path
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def src10(gradient_components, tmp_path_factory):
|
||||
"""10-bit h264 mp4 (Create Video with bit_depth=10)"""
|
||||
path = str(tmp_path_factory.mktemp("video") / "src10.mp4")
|
||||
VideoFromComponents(gradient_components, bit_depth=10).save_to(path)
|
||||
return path
|
||||
|
||||
|
||||
def probe(path):
|
||||
"""(codec, pix_fmt, bit_depth) of the first video stream"""
|
||||
with av.open(path) as container:
|
||||
stream = container.streams.video[0]
|
||||
return (stream.codec.name, stream.format.name, max(c.bits for c in stream.format.components))
|
||||
|
||||
|
||||
def decoded_levels(path):
|
||||
"""Unique tonal levels in the first decoded frame (banding measure)"""
|
||||
with av.open(path) as container:
|
||||
frame = next(container.decode(container.streams.video[0]))
|
||||
return len(np.unique(frame.to_ndarray(format="gbrpf32le")[..., 0]))
|
||||
|
||||
|
||||
def video_packet_bytes(path):
|
||||
"""Raw video packet payloads; identical to the source's only for a true remux"""
|
||||
with av.open(path) as container:
|
||||
return [bytes(p) for p in container.demux(container.streams.video[0]) if p.size]
|
||||
|
||||
|
||||
def test_create_video_bit_depth(src8, src10):
|
||||
"""Create Video's bit_depth picks the encoded depth (default 8-bit); 10-bit reduces banding"""
|
||||
assert probe(src8) == ("h264", "yuv420p", 8)
|
||||
assert probe(src10) == ("h264", "yuv420p10le", 10)
|
||||
assert decoded_levels(src10) > 2 * decoded_levels(src8)
|
||||
|
||||
|
||||
def test_save_auto_keeps_source_depth(src8, src10, tmp_path):
|
||||
"""Save Video (no bit_depth = auto) stream-copies the source, preserving its depth byte-for-byte"""
|
||||
for name, src in [("p8", src8), ("p10", src10)]:
|
||||
path = str(tmp_path / f"{name}.mp4")
|
||||
VideoFromFile(src).save_to(path)
|
||||
assert probe(path) == probe(src)
|
||||
assert video_packet_bytes(path) == video_packet_bytes(src)
|
||||
|
||||
|
||||
def test_save_explicit_depth_reencodes(src8, src10, tmp_path):
|
||||
"""An explicit bit_depth different from the source forces a re-encode to that depth"""
|
||||
down = str(tmp_path / "down8.mp4")
|
||||
VideoFromFile(src10).save_to(down, bit_depth=8)
|
||||
assert probe(down) == ("h264", "yuv420p", 8)
|
||||
|
||||
up = str(tmp_path / "up10.mp4")
|
||||
VideoFromFile(src8).save_to(up, bit_depth=10)
|
||||
assert probe(up) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_trim_keeps_source_depth(src10, tmp_path):
|
||||
"""Video Slice re-encodes (trim) but preserves the source's 10-bit depth"""
|
||||
path = str(tmp_path / "trim.mp4")
|
||||
VideoFromFile(src10).as_trimmed(start_time=0, duration=1 / 30, strict_duration=False).save_to(path)
|
||||
assert probe(path) == ("h264", "yuv420p10le", 10)
|
||||
|
||||
|
||||
def test_get_bit_depth(gradient_components, src8, src10):
|
||||
"""get_bit_depth reports a video's depth (backs the Get Video Components output)"""
|
||||
assert VideoFromFile(src8).get_bit_depth() == 8
|
||||
assert VideoFromFile(src10).get_bit_depth() == 10
|
||||
assert VideoFromComponents(gradient_components, bit_depth=10).get_bit_depth() == 10
|
||||
assert VideoFromComponents(gradient_components).get_bit_depth() == 8
|
||||
205
tests-unit/execution_test/test_enrich_output.py
Normal file
205
tests-unit/execution_test/test_enrich_output.py
Normal file
@ -0,0 +1,205 @@
|
||||
"""Tests for enrich_output_with_assets in comfy_execution/asset_enrichment.py."""
|
||||
import os
|
||||
import types
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
|
||||
def _make_args(enable_assets: bool):
|
||||
a = types.SimpleNamespace()
|
||||
a.enable_assets = enable_assets
|
||||
return a
|
||||
|
||||
|
||||
def _make_register_result(ref_id="ref-id-2"):
|
||||
result = MagicMock()
|
||||
result.ref.id = ref_id
|
||||
return result
|
||||
|
||||
|
||||
# Platform-appropriate absolute base. tempfile.gettempdir() returns C:\... on
|
||||
# Windows and /tmp on POSIX, so containment via commonpath behaves naturally.
|
||||
_DEFAULT_BASE = os.path.join(__import__("tempfile").gettempdir(), "asset-enrichment-test-base")
|
||||
|
||||
|
||||
def _mocked_modules(*, enable_assets=True, register_file_in_place=None, directory=_DEFAULT_BASE):
|
||||
return {
|
||||
"comfy.cli_args": MagicMock(args=_make_args(enable_assets)),
|
||||
"folder_paths": MagicMock(get_directory_by_type=MagicMock(return_value=directory)),
|
||||
"app.assets.services.ingest": MagicMock(
|
||||
register_file_in_place=register_file_in_place or MagicMock(return_value=_make_register_result()),
|
||||
DependencyMissingError=type("DependencyMissingError", (Exception,), {}),
|
||||
),
|
||||
}
|
||||
|
||||
|
||||
def _call(output_ui, *, enable_assets=True, file_exists=True, register_result=None, directory=_DEFAULT_BASE):
|
||||
register_mock = MagicMock(return_value=register_result or _make_register_result())
|
||||
mocked = _mocked_modules(
|
||||
enable_assets=enable_assets,
|
||||
register_file_in_place=register_mock,
|
||||
directory=directory,
|
||||
)
|
||||
|
||||
# Only os.path.isfile is patched — abspath/join must run natively so the
|
||||
# containment check sees real platform paths.
|
||||
with patch.dict("sys.modules", mocked), \
|
||||
patch("os.path.isfile", return_value=file_exists):
|
||||
import importlib
|
||||
import comfy_execution.asset_enrichment as mod
|
||||
importlib.reload(mod)
|
||||
return mod.enrich_output_with_assets(output_ui)
|
||||
|
||||
|
||||
class TestEnrichOutputWithAssets(unittest.TestCase):
|
||||
|
||||
def test_disabled_returns_unchanged(self):
|
||||
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "output"}]}
|
||||
result = _call(output, enable_assets=False)
|
||||
self.assertNotIn("id", result["images"][0])
|
||||
|
||||
def test_non_list_value_passed_through(self):
|
||||
output = {"text": "hello"}
|
||||
result = _call(output)
|
||||
self.assertEqual(result["text"], "hello")
|
||||
|
||||
def test_entry_without_filename_unchanged(self):
|
||||
output = {"latent": [{"subfolder": "", "type": "output"}]}
|
||||
result = _call(output)
|
||||
self.assertNotIn("id", result["latent"][0])
|
||||
|
||||
def test_entry_without_type_unchanged(self):
|
||||
output = {"data": [{"filename": "a.png", "subfolder": ""}]}
|
||||
result = _call(output)
|
||||
self.assertNotIn("id", result["data"][0])
|
||||
|
||||
def test_file_not_on_disk_unchanged(self):
|
||||
output = {"images": [{"filename": "missing.png", "subfolder": "", "type": "output"}]}
|
||||
result = _call(output, file_exists=False)
|
||||
self.assertNotIn("id", result["images"][0])
|
||||
|
||||
def test_unknown_type_returns_none_directory_unchanged(self):
|
||||
output = {"images": [{"filename": "a.png", "subfolder": "", "type": "unknown"}]}
|
||||
result = _call(output, directory=None)
|
||||
self.assertNotIn("id", result["images"][0])
|
||||
|
||||
def test_register_injects_only_id(self):
|
||||
reg = _make_register_result(ref_id="inline-ref")
|
||||
output = {"images": [{"filename": "new.png", "subfolder": "", "type": "output"}]}
|
||||
result = _call(output, register_result=reg)
|
||||
img = result["images"][0]
|
||||
self.assertEqual(img["id"], "inline-ref")
|
||||
# Only id is injected — no asset_hash, name, preview_url, size
|
||||
self.assertNotIn("asset_hash", img)
|
||||
self.assertNotIn("name", img)
|
||||
self.assertNotIn("preview_url", img)
|
||||
self.assertNotIn("size", img)
|
||||
|
||||
def test_register_called_per_entry(self):
|
||||
register_mock = MagicMock(return_value=_make_register_result())
|
||||
mocked = _mocked_modules(register_file_in_place=register_mock)
|
||||
output = {
|
||||
"images": [
|
||||
{"filename": "a.png", "subfolder": "", "type": "output"},
|
||||
{"filename": "b.png", "subfolder": "", "type": "output"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.dict("sys.modules", mocked), \
|
||||
patch("os.path.isfile", return_value=True):
|
||||
import importlib
|
||||
import comfy_execution.asset_enrichment as mod
|
||||
importlib.reload(mod)
|
||||
mod.enrich_output_with_assets(output)
|
||||
|
||||
self.assertEqual(register_mock.call_count, 2)
|
||||
|
||||
def test_original_entry_not_mutated(self):
|
||||
orig = {"filename": "a.png", "subfolder": "", "type": "output"}
|
||||
output = {"images": [orig]}
|
||||
_call(output)
|
||||
self.assertNotIn("id", orig)
|
||||
|
||||
def test_enrichment_error_does_not_block_sibling_entries(self):
|
||||
call_count = [0]
|
||||
good_reg = _make_register_result(ref_id="good-ref")
|
||||
|
||||
def register_side_effect(abs_path, name, tags):
|
||||
call_count[0] += 1
|
||||
if call_count[0] == 1:
|
||||
raise RuntimeError("boom")
|
||||
return good_reg
|
||||
|
||||
mocked = _mocked_modules(register_file_in_place=register_side_effect)
|
||||
|
||||
output = {
|
||||
"images": [
|
||||
{"filename": "bad.png", "subfolder": "", "type": "output"},
|
||||
{"filename": "good.png", "subfolder": "", "type": "output"},
|
||||
]
|
||||
}
|
||||
|
||||
with patch.dict("sys.modules", mocked), \
|
||||
patch("os.path.isfile", return_value=True):
|
||||
import importlib
|
||||
import comfy_execution.asset_enrichment as mod
|
||||
importlib.reload(mod)
|
||||
result = mod.enrich_output_with_assets(output)
|
||||
|
||||
imgs = result["images"]
|
||||
self.assertNotIn("id", imgs[0])
|
||||
self.assertEqual(imgs[1]["id"], "good-ref")
|
||||
|
||||
def test_multiple_output_keys_all_enriched(self):
|
||||
output = {
|
||||
"images": [{"filename": "a.png", "subfolder": "", "type": "output"}],
|
||||
"videos": [{"filename": "b.mp4", "subfolder": "", "type": "output"}],
|
||||
}
|
||||
result = _call(output)
|
||||
self.assertIn("id", result["images"][0])
|
||||
self.assertIn("id", result["videos"][0])
|
||||
|
||||
def test_none_entry_in_list_unchanged(self):
|
||||
output = {"images": [None, {"filename": "a.png", "subfolder": "", "type": "output"}]}
|
||||
result = _call(output)
|
||||
self.assertIsNone(result["images"][0])
|
||||
self.assertIn("id", result["images"][1])
|
||||
|
||||
def test_path_traversal_subfolder_skipped(self):
|
||||
register_mock = MagicMock(return_value=_make_register_result())
|
||||
mocked = _mocked_modules(register_file_in_place=register_mock)
|
||||
|
||||
output = {"images": [{"filename": "passwd", "subfolder": "../../etc", "type": "output"}]}
|
||||
|
||||
# Do NOT patch os.path.abspath — real resolution is required for the containment check.
|
||||
with patch.dict("sys.modules", mocked), \
|
||||
patch("os.path.isfile", return_value=True):
|
||||
import importlib
|
||||
import comfy_execution.asset_enrichment as mod
|
||||
importlib.reload(mod)
|
||||
result = mod.enrich_output_with_assets(output)
|
||||
|
||||
self.assertNotIn("id", result["images"][0])
|
||||
register_mock.assert_not_called()
|
||||
|
||||
def test_absolute_filename_skipped(self):
|
||||
register_mock = MagicMock(return_value=_make_register_result())
|
||||
mocked = _mocked_modules(register_file_in_place=register_mock)
|
||||
|
||||
# Absolute filename — os.path.join discards earlier components when a later one is absolute.
|
||||
absolute_filename = os.path.abspath(os.sep + "etc" + os.sep + "passwd")
|
||||
output = {"images": [{"filename": absolute_filename, "subfolder": "", "type": "output"}]}
|
||||
|
||||
with patch.dict("sys.modules", mocked), \
|
||||
patch("os.path.isfile", return_value=True):
|
||||
import importlib
|
||||
import comfy_execution.asset_enrichment as mod
|
||||
importlib.reload(mod)
|
||||
result = mod.enrich_output_with_assets(output)
|
||||
|
||||
self.assertNotIn("id", result["images"][0])
|
||||
register_mock.assert_not_called()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
@ -1,5 +1,7 @@
|
||||
"""Unit tests for comfy_execution/jobs.py"""
|
||||
|
||||
import pytest
|
||||
|
||||
from comfy_execution.jobs import (
|
||||
JobStatus,
|
||||
is_previewable,
|
||||
@ -10,9 +12,50 @@ from comfy_execution.jobs import (
|
||||
get_outputs_summary,
|
||||
apply_sorting,
|
||||
has_3d_extension,
|
||||
validate_job_id,
|
||||
)
|
||||
|
||||
|
||||
class TestValidateJobId:
|
||||
"""validate_job_id guards job creation: POST /prompt rejects ids it raises on."""
|
||||
|
||||
def test_canonical_form_passes_through(self):
|
||||
cid = "a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7"
|
||||
assert validate_job_id(cid) == cid
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"variant",
|
||||
[
|
||||
"A1B2C3D4-E5F6-7A89-B0C1-D2E3F4A5B6C7", # uppercase
|
||||
"{a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7}", # braced
|
||||
"urn:uuid:a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7", # URN
|
||||
"a1b2c3d4e5f67a89b0c1d2e3f4a5b6c7", # bare hex
|
||||
" a1b2c3d4-e5f6-7a89-b0c1-d2e3f4a5b6c7 ", # padded
|
||||
],
|
||||
)
|
||||
def test_non_canonical_spellings_rejected(self, variant):
|
||||
# uuid.UUID parses all of these, but accepting them would silently
|
||||
# rewrite the client's id (history keys, websocket events, and
|
||||
# /interrupt matching all match the stored form exactly).
|
||||
with pytest.raises(ValueError):
|
||||
validate_job_id(variant)
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"bad",
|
||||
["", "not-a-uuid", "prompt-123", "a1b2c3d4-e5f6-7a89-b0c1", "None"],
|
||||
)
|
||||
def test_non_uuid_strings_rejected(self, bad):
|
||||
with pytest.raises(ValueError):
|
||||
validate_job_id(bad)
|
||||
|
||||
@pytest.mark.parametrize("bad", [123, 1.5, True, None, ["a"], {"id": "x"}])
|
||||
def test_non_strings_rejected(self, bad):
|
||||
# uuid.UUID raises AttributeError/TypeError on non-strings; the helper
|
||||
# must normalize those to ValueError so callers need one except clause.
|
||||
with pytest.raises(ValueError):
|
||||
validate_job_id(bad)
|
||||
|
||||
|
||||
class TestJobStatus:
|
||||
"""Test JobStatus constants."""
|
||||
|
||||
|
||||
Loading…
Reference in New Issue
Block a user